Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe |
@@ -4,12 +4,19 @@ early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
|
||||
@@ -10,17 +10,44 @@ on:
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
- 'docker/*.Containerfile'
|
||||
|
||||
# grant permissions on GITHUB_TOKEN to publish packages
|
||||
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [intel, cuda, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -31,7 +58,7 @@ jobs:
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Windows CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
@@ -28,7 +44,7 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
@@ -43,7 +59,7 @@ jobs:
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Linux CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -3,13 +3,13 @@ name: goreleaser
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
description: "Tag version to release (e.g. v144)"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
@@ -19,35 +19,30 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Set up Node.js
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||
distribution: goreleaser
|
||||
# 'latest', 'nightly', or a semver
|
||||
version: '~> v2'
|
||||
version: "~> v2"
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -76,4 +71,4 @@ jobs:
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
name: UI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ui-svelte
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ui-svelte/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Type check
|
||||
run: npm run check
|
||||
|
||||
- name: Run tests
|
||||
run: npm test
|
||||
@@ -0,0 +1,50 @@
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
fixes #123
|
||||
```
|
||||
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
@@ -1,43 +1 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
|
||||
### Implementation of plans
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
|
||||
## General Rules
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
@AGENTS.md
|
||||
|
||||
@@ -36,11 +36,11 @@ test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
cd ui-svelte && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui && npm run build
|
||||
cd ui-svelte && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac: ui
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
@@ -13,14 +13,21 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
@@ -32,6 +39,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
@@ -44,7 +52,6 @@ llama-swap includes a real time web interface for monitoring logs and controllin
|
||||
|
||||
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
||||
|
||||
|
||||
The Activity Page shows recent requests:
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
@@ -61,7 +68,8 @@ llama-swap can be installed in multiple ways
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||
|
||||
```shell
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
@@ -71,6 +79,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# configuration hot reload supported with a
|
||||
# directory volume mount
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
-v /path/to/config:/config \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||
```
|
||||
|
||||
<details>
|
||||
@@ -89,6 +105,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
# non-root cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -191,23 +210,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```shell
|
||||
```sh
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
$ curl http://host/logs
|
||||
|
||||
# streams combined logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
curl -Ns http://host/logs/stream
|
||||
|
||||
# just llama-swap's logs
|
||||
curl -Ns 'http://host/logs/stream/proxy'
|
||||
# stream llama-swap's proxy status logs
|
||||
curl -Ns http://host/logs/stream/proxy
|
||||
|
||||
# just upstream's logs
|
||||
curl -Ns 'http://host/logs/stream/upstream'
|
||||
# stream logs from upstream processes that llama-swap loads
|
||||
curl -Ns http://host/logs/stream/upstream
|
||||
|
||||
# stream logs only from a specific model
|
||||
curl -Ns http://host/logs/stream/{model_id}
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
# appending ?no-history will disable sending buffered history first
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||
|
||||
## Current Issues
|
||||
|
||||
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||
4. Linked list structure has poor cache locality and pointer overhead
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### New CircularBuffer Type
|
||||
|
||||
Create a simple circular byte buffer with:
|
||||
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||
- `head` and `size` integers to track write position and data length
|
||||
- No per-write allocations
|
||||
|
||||
### API Requirements
|
||||
|
||||
The new buffer must support:
|
||||
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||
|
||||
### Implementation Details
|
||||
|
||||
```go
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
```
|
||||
|
||||
**Write logic:**
|
||||
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||
- Update `head` and `size` accordingly
|
||||
- Data is copied into the internal buffer (not stored by reference)
|
||||
|
||||
**GetHistory logic:**
|
||||
- Calculate start position: `(head - size + cap) % cap`
|
||||
- If not wrapped: single slice copy
|
||||
- If wrapped: two copies (end of buffer + beginning)
|
||||
- Returns a **new slice** (copy), not a view into internal buffer
|
||||
|
||||
### Immutability Guarantees (must preserve)
|
||||
|
||||
Per existing tests:
|
||||
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||
|
||||
## Files to Modify
|
||||
|
||||
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||
|
||||
## Testing Plan
|
||||
|
||||
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||
|
||||
Add new tests:
|
||||
- Test buffer wrap-around behavior
|
||||
- Test large writes that exceed buffer capacity
|
||||
- Test exact capacity boundary conditions
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||
- [ ] Implement `Write()` method for circular buffer
|
||||
- [ ] Implement `GetHistory()` method for circular buffer
|
||||
- [ ] Update `LogMonitor` struct to use new buffer
|
||||
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||
- [ ] Remove `"container/ring"` import
|
||||
- [ ] Run `make test-dev` to verify existing tests pass
|
||||
- [ ] Add wrap-around test case
|
||||
- [ ] Run `make test-all` for final validation
|
||||
@@ -210,6 +210,11 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||
model := c.Query("model")
|
||||
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
@@ -87,6 +87,12 @@
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"captureBuffer": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
@@ -188,11 +194,17 @@
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Only stripParams is supported."
|
||||
"description": "Dictionary of filter settings. Supports stripParams and setParams."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -273,6 +285,78 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,12 +34,27 @@ logLevel: info
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
@@ -70,6 +85,9 @@ includeAliasesInList: false
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
@@ -82,6 +100,24 @@ macros:
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -165,7 +201,7 @@ models:
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
@@ -175,6 +211,16 @@ models:
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
@@ -321,3 +367,56 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
@@ -1,28 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
@@ -32,25 +54,76 @@ LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the server tag
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
SD_TAG=master-${ARCH}
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
@@ -68,10 +141,22 @@ for CONTAINER_TYPE in non-root root; do
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||
case "$ARCH" in
|
||||
"musa" | "vulkan")
|
||||
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||
--build-arg BASE=${CONTAINER_TAG} \
|
||||
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||
esac
|
||||
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
|
||||
@@ -15,4 +15,19 @@ models:
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
--port 9999
|
||||
|
||||
z-image:
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||
@@ -0,0 +1,11 @@
|
||||
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||
ARG SD_TAG=master-vulkan
|
||||
ARG BASE=llama-swap:latest
|
||||
|
||||
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||
FROM ${BASE}
|
||||
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
|
||||
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||
@@ -29,6 +29,10 @@ RUN chown --recursive $UID:$GID $HOME /app
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Add /app to PATH
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
|
||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 351 KiB |
|
After Width: | Height: | Size: 198 KiB |
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
@@ -114,6 +117,24 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
@@ -126,6 +147,30 @@ metricsMaxInMemory: 1000
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -274,6 +319,10 @@ models:
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -383,4 +432,36 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
apiKey: sk-your-openrouter-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
```
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
## Container Security
|
||||
|
||||
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||
|
||||
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||
|
||||
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||
|
||||
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
@@ -81,6 +87,7 @@ type GroupConfig struct {
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
@@ -114,7 +121,9 @@ type Config struct {
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -136,6 +145,12 @@ type Config struct {
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -170,22 +185,30 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// default configuration values
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
@@ -193,6 +216,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
@@ -204,55 +233,46 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -262,23 +282,20 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
@@ -287,18 +304,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
@@ -306,10 +319,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
@@ -319,7 +330,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
@@ -333,35 +344,27 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
continue // replaced at runtime
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
// if sendLoadingState is nil, set it to the global config value
|
||||
// see #366
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState // copy it
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
@@ -369,18 +372,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
@@ -388,7 +390,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
@@ -400,10 +402,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -534,20 +582,26 @@ func validateMacro(name string, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -555,7 +609,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -614,3 +668,67 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
@@ -214,6 +215,7 @@ groups:
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -761,3 +761,615 @@ models:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
content: `apiKeys: [""]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, tt.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret-key-123")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY}"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY_1", "key-one")
|
||||
t.Setenv("TEST_API_KEY_2", "key-two")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("missing env var in apiKeys", func(t *testing.T) {
|
||||
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
|
||||
})
|
||||
|
||||
t.Run("env substitution results in empty key", func(t *testing.T) {
|
||||
t.Setenv("TEST_EMPTY_KEY", "")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "empty api key found in apiKeys", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_EnvMacros(t *testing.T) {
|
||||
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.TEST_MODEL_PATH}/llama-server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env substitution in multiple fields", func(t *testing.T) {
|
||||
t.Setenv("TEST_HOST", "myserver")
|
||||
t.Setenv("TEST_PORT", "9999")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --host ${env.TEST_HOST}"
|
||||
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
|
||||
checkEndpoint: "http://${env.TEST_HOST}/health"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
|
||||
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
|
||||
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
|
||||
})
|
||||
|
||||
t.Run("env in global macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_BASE_PATH", "/usr/local")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
|
||||
models:
|
||||
test:
|
||||
cmd: "${SERVER_PATH} --port 8080"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in model-level macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_DIR", "/models/llama")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
|
||||
cmd: "server --model ${MODEL_FILE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in metadata", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
api_key: "${env.TEST_API_KEY}"
|
||||
nested:
|
||||
key: "${env.TEST_API_KEY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
|
||||
nested := config.Models["test"].Metadata["nested"].(map[string]any)
|
||||
assert.Equal(t, "secret123", nested["key"])
|
||||
})
|
||||
|
||||
t.Run("env in filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("env in cmdStop", func(t *testing.T) {
|
||||
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --port ${PORT}"
|
||||
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
|
||||
proxy: "http://localhost:${PORT}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
|
||||
})
|
||||
|
||||
t.Run("missing env var returns error", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.UNDEFINED_VAR_12345}/server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in global macro", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in model macro", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in metadata", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
key: "${env.UNDEFINED_META_VAR}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env combined with regular macros", func(t *testing.T) {
|
||||
t.Setenv("TEST_ROOT", "/data")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
MODEL_BASE: "${env.TEST_ROOT}/models"
|
||||
models:
|
||||
test:
|
||||
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("multiple env vars in same string", func(t *testing.T) {
|
||||
t.Setenv("TEST_USER", "admin")
|
||||
t.Setenv("TEST_PASS", "secret")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env value with newline is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_MULTILINE", "line1\nline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_MULTILINE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_MULTILINE")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with carriage return is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_CR", "line1\rline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_CR}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_CR")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_QUOTED", `value with "quotes"`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --arg \"${env.TEST_QUOTED}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Quotes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with quotes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
|
||||
})
|
||||
|
||||
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_BACKSLASH", `path\to\file`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --path \"${env.TEST_BACKSLASH}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Backslashes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with backslashes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in peer apiKey", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.TEST_PEER_API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("missing env var in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.NONEXISTENT_PEER_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
|
||||
})
|
||||
|
||||
t.Run("static apiKey unchanged", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-static-key
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_KEY_1", "key-one")
|
||||
t.Setenv("TEST_PEER_KEY_2", "key-two")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: https://peer1.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_1}"
|
||||
models:
|
||||
- model-a
|
||||
peer2:
|
||||
proxy: https://peer2.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_2}"
|
||||
models:
|
||||
- model-b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
|
||||
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
API_KEY: sk-from-global-macro
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
STRIP_LIST: "temperature, top_p"
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${STRIP_LIST}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
MAX_TOKENS: 4096
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
max_tokens: "${MAX_TOKENS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_RETENTION_POLICY", "deny")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
data_collection: "${env.TEST_RETENTION_POLICY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${UNDEFINED_MACRO}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
value: "${UNDEFINED_MACRO}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("env macros in comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
# apiKeys:
|
||||
# - "${env.COMMENTED_OUT_KEY_1}"
|
||||
# - "${env.COMMENTED_OUT_KEY_2}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
// These env vars are NOT set, but should not cause an error
|
||||
// because they only appear in comment lines
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in comments ignored while active ones resolve", func(t *testing.T) {
|
||||
t.Setenv("TEST_ACTIVE_KEY", "active-key-value")
|
||||
|
||||
content := `
|
||||
# apiKeys: ["${env.COMMENTED_OUT_KEY}"]
|
||||
apiKeys: ["${env.TEST_ACTIVE_KEY}"]
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"active-key-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in indented comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: |
|
||||
server
|
||||
--port 8080
|
||||
proxy: "http://localhost:8080"
|
||||
# metadata:
|
||||
# api_key: "${env.SOME_UNSET_KEY}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("env macros in inline comments are ignored", func(t *testing.T) {
|
||||
t.Setenv("TEST_INLINE_KEY", "real-value")
|
||||
|
||||
content := `
|
||||
apiKeys: ["${env.TEST_INLINE_KEY}"] # TODO: add ${env.FUTURE_KEY} later
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"real-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -158,6 +158,7 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
@@ -203,6 +204,7 @@ groups:
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -3,8 +3,6 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
@@ -74,16 +72,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
@@ -104,25 +101,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
|
||||
@@ -72,3 +72,35 @@ models:
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
|
||||
@@ -71,11 +71,15 @@ func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,6 +11,85 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
|
||||
func newCircularBuffer(capacity int) *circularBuffer {
|
||||
return &circularBuffer{
|
||||
data: make([]byte, capacity),
|
||||
head: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||
// Data is copied into the internal buffer (not stored by reference).
|
||||
func (cb *circularBuffer) Write(p []byte) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cap := len(cb.data)
|
||||
|
||||
// If input is larger than capacity, only keep the last cap bytes
|
||||
if len(p) >= cap {
|
||||
copy(cb.data, p[len(p)-cap:])
|
||||
cb.head = 0
|
||||
cb.size = cap
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate how much space is available from head to end of buffer
|
||||
firstPart := cap - cb.head
|
||||
if firstPart >= len(p) {
|
||||
// All data fits without wrapping
|
||||
copy(cb.data[cb.head:], p)
|
||||
cb.head = (cb.head + len(p)) % cap
|
||||
} else {
|
||||
// Data wraps around
|
||||
copy(cb.data[cb.head:], p[:firstPart])
|
||||
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||
cb.head = len(p) - firstPart
|
||||
}
|
||||
|
||||
// Update size
|
||||
cb.size += len(p)
|
||||
if cb.size > cap {
|
||||
cb.size = cap
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||
// Returns a new slice (copy), not a view into internal buffer.
|
||||
func (cb *circularBuffer) GetHistory() []byte {
|
||||
if cb.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, cb.size)
|
||||
cap := len(cb.data)
|
||||
|
||||
// Calculate start position (oldest data)
|
||||
start := (cb.head - cb.size + cap) % cap
|
||||
|
||||
if start+cb.size <= cap {
|
||||
// Data is contiguous, single copy
|
||||
copy(result, cb.data[start:start+cb.size])
|
||||
} else {
|
||||
// Data wraps around, two copies
|
||||
firstPart := cap - start
|
||||
copy(result[:firstPart], cb.data[start:])
|
||||
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
@@ -19,12 +97,14 @@ const (
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
|
||||
LogBufferSize = 100 * 1024
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
buffer *circularBuffer
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
@@ -45,7 +125,7 @@ func NewLogMonitor() *LogMonitor {
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
buffer: nil, // lazy initialized on first Write
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
@@ -64,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
if w.buffer == nil {
|
||||
w.buffer = newCircularBuffer(LogBufferSize)
|
||||
}
|
||||
w.buffer.Write(p)
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
// Make a copy for broadcast to preserve immutability
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
@@ -77,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return w.buffer.GetHistory()
|
||||
}
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
})
|
||||
return history
|
||||
// Clear releases the buffer memory, making it eligible for GC.
|
||||
// The buffer will be lazily re-allocated on the next Write.
|
||||
func (w *LogMonitor) Clear() {
|
||||
w.bufferMu.Lock()
|
||||
w.buffer = nil
|
||||
w.bufferMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
|
||||
@@ -113,3 +113,204 @@ func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||
// Create a small buffer to test wrap-around
|
||||
cb := newCircularBuffer(10)
|
||||
|
||||
// Write "hello" (5 bytes)
|
||||
cb.Write([]byte("hello"))
|
||||
if got := string(cb.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Write "world" (5 bytes) - buffer now full
|
||||
cb.Write([]byte("world"))
|
||||
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||
t.Errorf("Expected 'helloworld', got %q", got)
|
||||
}
|
||||
|
||||
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||
cb.Write([]byte("12345"))
|
||||
if got := string(cb.GetHistory()); got != "world12345" {
|
||||
t.Errorf("Expected 'world12345', got %q", got)
|
||||
}
|
||||
|
||||
// Write data larger than buffer capacity
|
||||
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||
// Test empty buffer
|
||||
cb := newCircularBuffer(10)
|
||||
if got := cb.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
|
||||
// Test exact capacity
|
||||
cb.Write([]byte("1234567890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
|
||||
// Test write exactly at capacity boundary
|
||||
cb = newCircularBuffer(10)
|
||||
cb.Write([]byte("12345"))
|
||||
cb.Write([]byte("67890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Buffer should be nil before any writes
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil before first write")
|
||||
}
|
||||
|
||||
// GetHistory should return nil when buffer is nil
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history before first write, got %q", got)
|
||||
}
|
||||
|
||||
// Write should lazily initialize the buffer
|
||||
lm.Write([]byte("test"))
|
||||
|
||||
if lm.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized after write")
|
||||
}
|
||||
|
||||
if got := string(lm.GetHistory()); got != "test" {
|
||||
t.Errorf("Expected 'test', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_Clear(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write some data
|
||||
lm.Write([]byte("hello"))
|
||||
if got := string(lm.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Clear should release the buffer
|
||||
lm.Clear()
|
||||
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil after Clear")
|
||||
}
|
||||
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write, clear, then write again
|
||||
lm.Write([]byte("first"))
|
||||
lm.Clear()
|
||||
lm.Write([]byte("second"))
|
||||
|
||||
if got := string(lm.GetHistory()); got != "second" {
|
||||
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||
// Test data of varying sizes
|
||||
smallMsg := []byte("small message\n")
|
||||
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||
|
||||
b.Run("SmallWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(smallMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MediumWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LargeWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(largeMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithSubscribers", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Add some subscribers
|
||||
for i := 0; i < 5; i++ {
|
||||
lm.OnLogData(func(data []byte) {})
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetHistory", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Pre-populate with data
|
||||
for i := 0; i < 1000; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.GetHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
Benchmark Results - MBP M1 Pro
|
||||
|
||||
Before (ring.Ring):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||
|
||||
After (circularBuffer 10KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||
|
||||
After (circularBuffer 100KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|-----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||
|
||||
Summary:
|
||||
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||
- Allocations: reduced from 2 to 1 across all operations
|
||||
- Small/medium writes: ~1.1-1.6x faster
|
||||
*/
|
||||
|
||||
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -26,6 +28,28 @@ type TokenMetrics struct {
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// Size returns the approximate memory usage of this capture in bytes
|
||||
func (c *ReqRespCapture) Size() int {
|
||||
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
||||
for k, v := range c.ReqHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
for k, v := range c.RespHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
@@ -44,19 +68,32 @@ type metricsMonitor struct {
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total size in bytes
|
||||
maxCaptureSize int // max bytes for captures
|
||||
}
|
||||
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||
mp := &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int]ReqRespCapture),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
@@ -67,6 +104,49 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
}
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
captureSize := capture.Size()
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
mp.captureSize -= evicted.Size()
|
||||
delete(mp.captures, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = capture
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
}
|
||||
|
||||
// getCaptureByID returns a capture by its ID, or nil if not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
if capture, exists := mp.captures[id]; exists {
|
||||
return &capture
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
@@ -95,7 +175,35 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,17 +216,35 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
@@ -126,19 +252,58 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
// Only set HasCapture if the capture will actually be stored (not too large)
|
||||
if capture.Size() <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -251,6 +416,25 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
@@ -289,3 +473,43 @@ func (w *responseBodyCopier) Header() http.Header {
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,7 +18,7 @@ import (
|
||||
|
||||
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -34,7 +37,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||
@@ -48,7 +51,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 3)
|
||||
mm := newMetricsMonitor(testLogger, 3, 0)
|
||||
|
||||
// Add 5 metrics
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -68,7 +71,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||
cancel := event.On(func(e TokenMetricsEvent) {
|
||||
@@ -98,14 +101,14 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
metrics := mm.getMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||
|
||||
@@ -125,7 +128,7 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
@@ -137,7 +140,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model1",
|
||||
InputTokens: 100,
|
||||
@@ -165,7 +168,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"usage": {
|
||||
@@ -196,7 +199,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("successful request with timings data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -236,7 +239,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||
@@ -272,7 +275,7 @@ data: [DONE]
|
||||
})
|
||||
|
||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -291,8 +294,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -307,11 +310,14 @@ data: [DONE]
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -328,11 +334,14 @@ data: [DONE]
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
expectedErr := assert.AnError
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -350,8 +359,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"result": "ok"}`
|
||||
|
||||
@@ -367,10 +376,82 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Infill response is an array with timings in the last element
|
||||
responseBody := `[
|
||||
{"content": "first chunk"},
|
||||
{"content": "second chunk"},
|
||||
{"content": "final", "timings": {
|
||||
"prompt_n": 150,
|
||||
"predicted_n": 75,
|
||||
"prompt_per_second": 200.5,
|
||||
"predicted_per_second": 35.5,
|
||||
"prompt_ms": 600.0,
|
||||
"predicted_ms": 1800.0,
|
||||
"cache_n": 30
|
||||
}}
|
||||
]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 150, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 30, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||
})
|
||||
|
||||
t.Run("infill request with empty array records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `[]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -425,7 +506,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
@@ -452,7 +533,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
@@ -490,7 +571,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Timings should take precedence over usage
|
||||
responseBody := `{
|
||||
@@ -530,7 +611,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -565,7 +646,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Metrics should be found in the last data line before [DONE]
|
||||
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||
@@ -598,8 +679,8 @@ data: [DONE]
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `data: not json
|
||||
|
||||
@@ -619,14 +700,17 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := ``
|
||||
|
||||
@@ -642,17 +726,19 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
// Empty body should not trigger WrapHandler processing
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -673,7 +759,7 @@ func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
|
||||
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -691,3 +777,352 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
t.Run("gzip encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
// Compress with gzip
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
gzWriter.Write([]byte(responseBody))
|
||||
gzWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("deflate encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
|
||||
|
||||
// Compress with deflate
|
||||
var buf bytes.Buffer
|
||||
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
|
||||
flateWriter.Write([]byte(responseBody))
|
||||
flateWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "deflate")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Invalid compressed data
|
||||
invalidData := []byte("this is not gzip data")
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(invalidData)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Should not return error, just log warning
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "unknown-encoding")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReqRespCapture_Size(t *testing.T) {
|
||||
t.Run("calculates size correctly", func(t *testing.T) {
|
||||
capture := ReqRespCapture{
|
||||
ID: 1,
|
||||
ReqPath: "/v1/chat/completions", // 20 bytes
|
||||
ReqHeaders: map[string]string{
|
||||
"Content-Type": "application/json", // 12 + 16 = 28
|
||||
},
|
||||
ReqBody: []byte("request body"), // 12 bytes
|
||||
RespHeaders: map[string]string{
|
||||
"X-Test": "value", // 6 + 5 = 11
|
||||
},
|
||||
RespBody: []byte("response body"), // 13 bytes
|
||||
}
|
||||
|
||||
// Expected: 20 + 12 + 13 + 28 + 11 = 84
|
||||
assert.Equal(t, 84, capture.Size())
|
||||
})
|
||||
|
||||
t.Run("handles empty capture", func(t *testing.T) {
|
||||
capture := ReqRespCapture{}
|
||||
assert.Equal(t, 0, capture.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
t.Run("does nothing when captures disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
// Should not store capture
|
||||
assert.Nil(t, mm.getCaptureByID(0))
|
||||
})
|
||||
|
||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test request"),
|
||||
RespBody: []byte("test response"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(0)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 0, retrieved.ID)
|
||||
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), retrieved.RespBody)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100 // Set small limit for test
|
||||
|
||||
// Add captures that will exceed the limit
|
||||
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
|
||||
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
|
||||
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
|
||||
|
||||
mm.addCapture(capture1)
|
||||
mm.addCapture(capture2)
|
||||
// Adding capture3 should evict capture1
|
||||
mm.addCapture(capture3)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
||||
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
||||
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
||||
})
|
||||
|
||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100
|
||||
|
||||
// Add a capture larger than max
|
||||
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
|
||||
mm.addCapture(largeCapture)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(999))
|
||||
})
|
||||
|
||||
t.Run("returns capture by ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 42,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(42)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 42, retrieved.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Run("redacts sensitive headers", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer secret-token",
|
||||
"Proxy-Authorization": "Basic creds",
|
||||
"Cookie": "session=abc123",
|
||||
"Set-Cookie": "session=xyz789",
|
||||
"X-Api-Key": "sk-12345",
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom": "safe-value",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Set-Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["X-Api-Key"])
|
||||
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||
assert.Equal(t, "safe-value", headers["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("handles mixed case header names", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"authorization": "Bearer token",
|
||||
"COOKIE": "session=abc",
|
||||
"x-api-key": "key123",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["COOKIE"])
|
||||
assert.Equal(t, "[REDACTED]", headers["x-api-key"])
|
||||
})
|
||||
|
||||
t.Run("handles empty headers", func(t *testing.T) {
|
||||
headers := map[string]string{}
|
||||
redactHeaders(headers)
|
||||
assert.Empty(t, headers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
t.Run("captures request and response when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
requestBody := `{"model": "test", "prompt": "hello"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Custom", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check metric was recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
metricID := metrics[0].ID
|
||||
|
||||
// Check capture was stored with same ID
|
||||
capture := mm.getCaptureByID(metricID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, metricID, capture.ID)
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
assert.Equal(t, "/test", capture.ReqPath)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("does not capture when disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
requestBody := `{"model": "test"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Metrics should still be recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
|
||||
// But no capture
|
||||
capture := mm.getCaptureByID(metrics[0].ID)
|
||||
assert.Nil(t, capture)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
// Create a shared transport with reasonable timeouts for peer connections
|
||||
// these can be tuned with feedback later
|
||||
peerTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second, // Connection timeout
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
@@ -414,6 +414,9 @@ func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
@@ -646,6 +649,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
@@ -864,7 +872,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
// Add Flush method
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
|
||||
@@ -395,6 +395,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
@@ -46,7 +46,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
@@ -88,6 +89,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
@@ -27,6 +28,40 @@ const (
|
||||
|
||||
type proxyCtxKey string
|
||||
|
||||
type InflightCounter struct {
|
||||
mu sync.Mutex
|
||||
total int
|
||||
}
|
||||
|
||||
func newInflightCounter() *InflightCounter {
|
||||
return &InflightCounter{}
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Current() int {
|
||||
ic.mu.Lock()
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Increment() int {
|
||||
ic.mu.Lock()
|
||||
ic.total++
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Decrement() int {
|
||||
ic.mu.Lock()
|
||||
if ic.total > 0 {
|
||||
ic.total--
|
||||
}
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
@@ -42,6 +77,8 @@ type ProxyManager struct {
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
inFlightCounter *InflightCounter
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
@@ -50,19 +87,42 @@ type ProxyManager struct {
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
|
||||
// peer proxy see: #296, #433
|
||||
peerProxy *PeerProxy
|
||||
}
|
||||
|
||||
func New(config config.Config) *ProxyManager {
|
||||
func New(proxyConfig config.Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
|
||||
if config.LogRequests {
|
||||
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
|
||||
switch proxyConfig.LogToStdout {
|
||||
case config.LogToStdoutNone:
|
||||
muxLogger = NewLogMonitorWriter(io.Discard)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
case config.LogToStdoutBoth:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
case config.LogToStdoutUpstream:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
default:
|
||||
// same as config.LogToStdoutProxy
|
||||
// helpful because some old tests create a config.Config directly and it
|
||||
// may not have LogToStdout set explicitly
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
}
|
||||
|
||||
if proxyConfig.LogRequests {
|
||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
upstreamLogger.SetLogLevel(LevelDebug)
|
||||
@@ -99,7 +159,7 @@ func New(config config.Config) *ProxyManager {
|
||||
"stampnano": time.StampNano,
|
||||
}
|
||||
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
||||
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||
}
|
||||
@@ -107,61 +167,78 @@ func New(config config.Config) *ProxyManager {
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
var maxMetrics int
|
||||
if config.MetricsMaxInMemory <= 0 {
|
||||
if proxyConfig.MetricsMaxInMemory <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
} else {
|
||||
maxMetrics = config.MetricsMaxInMemory
|
||||
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||
peerProxy = nil
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
config: proxyConfig,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
muxLogger: muxLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
inFlightCounter: newInflightCounter(),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
for groupID := range proxyConfig.Groups {
|
||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
|
||||
// run any startup hooks
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||
|
||||
if !ok {
|
||||
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
@@ -236,37 +313,45 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
|
||||
// Protected routes use pm.apiKeyAuth() middleware
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic count_tokens API (Also added in the above PR)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -278,9 +363,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
@@ -302,25 +387,35 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
if err != nil {
|
||||
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||
} else {
|
||||
// Serve files with compression support under /ui/*
|
||||
// This handler checks for pre-compressed .br and .gz files
|
||||
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
||||
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
||||
// Default to index.html for directory-like paths
|
||||
if filepath == "" {
|
||||
filepath = "index.html"
|
||||
}
|
||||
|
||||
// serve files that exist under /ui/*
|
||||
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
||||
})
|
||||
|
||||
// server SPA for UI under /ui/*
|
||||
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
||||
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
file, err := reactFS.Open("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
// Check if this looks like a file request (has extension)
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
||||
// This was likely a file request that wasn't found
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||
|
||||
// Serve index.html for SPA routing
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -332,6 +427,14 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
gin.DisableConsoleColor()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
||||
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
@@ -378,16 +481,10 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
@@ -399,54 +496,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
return processGroup, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := func(modelId string) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
data = append(data, newRecord(id))
|
||||
data = append(data, newRecord(id, modelConfig))
|
||||
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias))
|
||||
data = append(data, newRecord(alias, modelConfig))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
// add peer models
|
||||
for _, modelID := range peer.Models {
|
||||
// Skip unlisted models if not showing them
|
||||
record := newRecord(modelID, config.ModelConfig{
|
||||
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||
Metadata: map[string]any{
|
||||
"peerID": peerID,
|
||||
},
|
||||
})
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
@@ -466,62 +580,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
// split the upstream path by / and search for the model name
|
||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||
if len(parts) == 0 {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
modelFound := false
|
||||
// findModelInPath searches for a valid model name in a path with slashes.
|
||||
// It iteratively builds up path segments until it finds a matching model.
|
||||
// Returns: (searchModelName, realModelName, remainingPath, found)
|
||||
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
||||
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
searchModelName := ""
|
||||
var modelName, remainingPath string
|
||||
|
||||
for i, part := range parts {
|
||||
if parts[i] == "" {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if searchModelName == "" {
|
||||
searchModelName = part
|
||||
} else {
|
||||
searchModelName = searchModelName + "/" + parts[i]
|
||||
searchModelName = searchModelName + "/" + part
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
break
|
||||
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||
// This ensures relative URLs in upstream responses resolve correctly and
|
||||
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||
// HTTP method (301 would downgrade to GET).
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -533,15 +646,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -560,41 +673,90 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -607,19 +769,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -639,9 +801,29 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
// Look for a matching local model first, then check peers
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var useModelName string
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -657,8 +839,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
@@ -728,9 +908,46 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||
requestedModel := c.Query("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
||||
return
|
||||
}
|
||||
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var modelID string
|
||||
|
||||
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
modelID = realModelID
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
modelID = requestedModel
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -745,6 +962,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||
// Returns a pass-through handler if no API keys are configured.
|
||||
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
xApiKey := c.GetHeader("x-api-key")
|
||||
|
||||
var bearerKey string
|
||||
var basicKey string
|
||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
// Basic Auth: base64(username:password), password is the API key
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
basicKey = parts[1] // password is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use first key found: Basic, then Bearer, then x-api-key
|
||||
var providedKey string
|
||||
if basicKey != "" {
|
||||
providedKey = basicKey
|
||||
} else if bearerKey != "" {
|
||||
providedKey = bearerKey
|
||||
} else {
|
||||
providedKey = xApiKey
|
||||
}
|
||||
|
||||
// Validate key
|
||||
valid := false
|
||||
for _, key := range pm.config.RequiredAPIKeys {
|
||||
if providedKey == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Strip auth headers to prevent leakage to upstream
|
||||
c.Request.Header.Del("Authorization")
|
||||
c.Request.Header.Del("x-api-key")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
@@ -758,8 +1036,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"cmd": process.config.Cmd,
|
||||
"proxy": process.config.Proxy,
|
||||
"ttl": process.config.UnloadAfter,
|
||||
"name": process.config.Name,
|
||||
"description": process.config.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,17 +19,20 @@ type Model struct {
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +86,18 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -91,6 +107,7 @@ const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
@@ -150,6 +167,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
@@ -177,11 +206,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -236,3 +273,20 @@ func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, capture)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
var logger *LogMonitor
|
||||
|
||||
if logMonitorId == "" {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
logger = pm.muxLogger
|
||||
} else if logMonitorId == "proxy" {
|
||||
logger = pm.proxyLogger
|
||||
} else if logMonitorId == "upstream" {
|
||||
logger = pm.upstreamLogger
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||
}
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := config.Config{
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
Peers: map[string]config.PeerConfig{
|
||||
"peer1": {
|
||||
Proxy: "http://peer1:8080",
|
||||
Models: []string{"peer-model-a", "peer-model-b"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(cfg)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
// Check the number of models returned (3 local + 2 peer models)
|
||||
assert.Len(t, response.Data, 5)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"peer-model-a": {},
|
||||
"peer-model-b": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
|
||||
// Peer models should have meta.llamaswap.peerID
|
||||
meta, exists := model["meta"]
|
||||
assert.True(t, exists, "peer model should have meta field")
|
||||
metaMap, ok := meta.(map[string]interface{})
|
||||
assert.True(t, ok, "meta should be a map")
|
||||
llamaswap, exists := metaMap["llamaswap"]
|
||||
assert.True(t, exists, "meta should have llamaswap field")
|
||||
llamaswapMap, ok := llamaswap.(map[string]interface{})
|
||||
assert.True(t, ok, "llamaswap should be a map")
|
||||
peerID, exists := llamaswapMap["peerID"]
|
||||
assert.True(t, exists, "llamaswap should have peerID field")
|
||||
assert.Equal(t, "peer1", peerID)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
@@ -502,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
@@ -650,8 +672,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
@@ -699,6 +726,11 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
|
||||
// Verify extended fields are present
|
||||
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
||||
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
||||
assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to 0")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -818,6 +850,43 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioVoicesGETHandler(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful GET with model query param", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "voice1")
|
||||
})
|
||||
|
||||
t.Run("missing model query param returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -944,7 +1013,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
Filters: config.Filters{
|
||||
StripParams: "temperature, model, stream",
|
||||
},
|
||||
}
|
||||
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
@@ -1078,7 +1149,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"author/model": getTestSimpleResponderConfig("author/model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
@@ -1091,6 +1163,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
"/logs/stream",
|
||||
"/logs/stream/proxy",
|
||||
"/logs/stream/upstream",
|
||||
"/logs/stream/author/model",
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
@@ -1185,3 +1258,349 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
|
||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("both headers with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
req.Header.Set("Authorization", "Bearer valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "invalid-key")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("missing key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
// Basic Auth: base64("anyuser:valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||
// Config without RequiredAPIKeys - auth should be disabled
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
|
||||
// in proxyInferenceHandler for issue #433
|
||||
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
|
||||
t.Run("requests to peer models are proxied", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config with peers but no local model for "peer-model"
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "from-peer")
|
||||
})
|
||||
|
||||
t.Run("local models take precedence over peer models", func(t *testing.T) {
|
||||
// Create a test server to act as the peer - should NOT be called
|
||||
peerCalled := false
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
peerCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config where "shared-model" exists both locally and on peer
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- shared-model
|
||||
models:
|
||||
shared-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-response
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"shared-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "local-response")
|
||||
assert.False(t, peerCalled, "peer should not be called when local model exists")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns error", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer API key is injected into request", func(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
apiKey: secret-peer-key
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
|
||||
})
|
||||
|
||||
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"local-model": getTestSimpleResponderConfig("local-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// peerProxy exists but has no peer models configured
|
||||
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("data: test\n\n"))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
node_modules
|
||||
.vite
|
||||
@@ -10,8 +10,8 @@
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body >
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"name": "ui-svelte",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"start": "vite",
|
||||
"build": "vite build --emptyOutDir",
|
||||
"preview": "vite preview",
|
||||
"check": "svelte-check --tsconfig ./tsconfig.json",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sveltejs/vite-plugin-svelte": "^5.0.3",
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@tsconfig/svelte": "^5.0.4",
|
||||
"@types/hast": "^3.0.4",
|
||||
"@types/node": "^25.1.0",
|
||||
"svelte": "^5.19.0",
|
||||
"svelte-check": "^4.1.4",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "~5.8.3",
|
||||
"vite": "^6.3.5",
|
||||
"vite-plugin-compression2": "^2.4.0",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.28",
|
||||
"lucide-svelte": "^0.563.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"remark-parse": "^11.0.0",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-spa-router": "^4.0.1",
|
||||
"unified": "^11.0.5",
|
||||
"unist-util-visit": "^5.1.0"
|
||||
}
|
||||
}
|
||||
|
Before Width: | Height: | Size: 5.9 KiB After Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 38 KiB After Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 6.5 KiB After Width: | Height: | Size: 6.5 KiB |
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,58 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import Router from "svelte-spa-router";
|
||||
import Header from "./components/Header.svelte";
|
||||
import LogViewer from "./routes/LogViewer.svelte";
|
||||
import Models from "./routes/Models.svelte";
|
||||
import Activity from "./routes/Activity.svelte";
|
||||
import Playground from "./routes/Playground.svelte";
|
||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||
import { enableAPIEvents } from "./stores/api";
|
||||
import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||
import { currentRoute } from "./stores/route";
|
||||
|
||||
const routes = {
|
||||
"/": PlaygroundStub,
|
||||
"/models": Models,
|
||||
"/logs": LogViewer,
|
||||
"/activity": Activity,
|
||||
"*": PlaygroundStub,
|
||||
};
|
||||
|
||||
function handleRouteLoaded(event: { detail: { route: string | RegExp } }) {
|
||||
const route = event.detail.route;
|
||||
currentRoute.set(typeof route === "string" ? route : "/");
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
document.documentElement.setAttribute("data-theme", $isDarkMode ? "dark" : "light");
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
const icon = $connectionState === "connecting" ? "\u{1F7E1}" : $connectionState === "connected" ? "\u{1F7E2}" : "\u{1F534}";
|
||||
document.title = `${icon} ${$appTitle}`;
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
const cleanupScreenWidth = initScreenWidth();
|
||||
enableAPIEvents(true);
|
||||
|
||||
return () => {
|
||||
cleanupScreenWidth();
|
||||
enableAPIEvents(false);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-screen">
|
||||
<Header />
|
||||
|
||||
<main class="flex-1 overflow-auto p-4">
|
||||
<div class="h-full" class:hidden={$currentRoute !== "/"}>
|
||||
<Playground />
|
||||
</div>
|
||||
<div class="h-full" class:hidden={$currentRoute === "/"}>
|
||||
<Router {routes} on:routeLoaded={handleRouteLoaded} />
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,452 @@
|
||||
<script lang="ts">
|
||||
import type { ReqRespCapture } from "../lib/types";
|
||||
|
||||
interface Props {
|
||||
capture: ReqRespCapture | null;
|
||||
open: boolean;
|
||||
onclose: () => void;
|
||||
}
|
||||
|
||||
let { capture, open, onclose }: Props = $props();
|
||||
|
||||
let dialogEl: HTMLDialogElement | undefined = $state();
|
||||
|
||||
type BodyTab = "raw" | "pretty" | "chat";
|
||||
let reqBodyTab: BodyTab = $state("pretty");
|
||||
let respBodyTab: BodyTab = $state("pretty");
|
||||
let copiedReq = $state(false);
|
||||
let copiedResp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
if (open && dialogEl) {
|
||||
dialogEl.showModal();
|
||||
} else if (!open && dialogEl) {
|
||||
dialogEl.close();
|
||||
}
|
||||
});
|
||||
|
||||
// Reset tabs when capture changes
|
||||
$effect(() => {
|
||||
if (capture) {
|
||||
const reqCt = getContentType(capture.req_headers);
|
||||
const respCt = getContentType(capture.resp_headers);
|
||||
reqBodyTab = reqCt.includes("json") ? "pretty" : "raw";
|
||||
respBodyTab = respCt.includes("text/event-stream")
|
||||
? "chat"
|
||||
: respCt.includes("json")
|
||||
? "pretty"
|
||||
: "raw";
|
||||
}
|
||||
});
|
||||
|
||||
function handleDialogClose() {
|
||||
onclose();
|
||||
}
|
||||
|
||||
function decodeBody(body: string | null | undefined): string {
|
||||
if (!body) return "";
|
||||
try {
|
||||
const binary = atob(body);
|
||||
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
|
||||
return new TextDecoder().decode(bytes);
|
||||
} catch {
|
||||
return body;
|
||||
}
|
||||
}
|
||||
|
||||
function formatJson(str: string): string {
|
||||
try {
|
||||
const parsed = JSON.parse(str);
|
||||
return JSON.stringify(parsed, null, 2);
|
||||
} catch {
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
function getContentType(
|
||||
headers: Record<string, string> | null | undefined,
|
||||
): string {
|
||||
if (!headers) return "";
|
||||
const ct = headers["Content-Type"] || headers["content-type"] || "";
|
||||
return ct.toLowerCase();
|
||||
}
|
||||
|
||||
function isImageContentType(contentType: string): boolean {
|
||||
return contentType.startsWith("image/");
|
||||
}
|
||||
|
||||
function isTextContentType(contentType: string): boolean {
|
||||
return (
|
||||
contentType.startsWith("text/") ||
|
||||
contentType.includes("application/json") ||
|
||||
contentType.includes("application/xml") ||
|
||||
contentType.includes("application/javascript")
|
||||
);
|
||||
}
|
||||
|
||||
function getImageDataUrl(body: string, contentType: string): string {
|
||||
const mimeType = contentType.split(";")[0].trim();
|
||||
return `data:${mimeType};base64,${body}`;
|
||||
}
|
||||
|
||||
interface SSEChat {
|
||||
reasoning: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
function parseSSEChat(text: string): SSEChat {
|
||||
const result: SSEChat = { reasoning: "", content: "" };
|
||||
for (const line of text.split("\n")) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith("data: ")) continue;
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
const delta = parsed.choices?.[0]?.delta;
|
||||
if (delta?.content) result.content += delta.content;
|
||||
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
||||
} catch {
|
||||
// skip unparseable lines
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async function copyToClipboard(text: string, type: "req" | "resp") {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
if (type === "req") {
|
||||
copiedReq = true;
|
||||
setTimeout(() => (copiedReq = false), 1500);
|
||||
} else {
|
||||
copiedResp = true;
|
||||
setTimeout(() => (copiedResp = false), 1500);
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
function getCopyText(): string {
|
||||
if (respBodyTab === "chat") {
|
||||
let text = "";
|
||||
if (sseChat.reasoning) text += sseChat.reasoning + "\n\n";
|
||||
text += sseChat.content;
|
||||
return text;
|
||||
}
|
||||
return displayedResponseBody;
|
||||
}
|
||||
|
||||
// Request body derivations
|
||||
let requestContentType = $derived(
|
||||
capture ? getContentType(capture.req_headers) : "",
|
||||
);
|
||||
let isRequestJson = $derived(requestContentType.includes("json"));
|
||||
|
||||
let requestBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.req_body);
|
||||
});
|
||||
|
||||
let requestBodyPretty = $derived.by(() => {
|
||||
if (!isRequestJson) return requestBodyRaw;
|
||||
return formatJson(requestBodyRaw);
|
||||
});
|
||||
|
||||
let displayedRequestBody = $derived(
|
||||
reqBodyTab === "pretty" ? requestBodyPretty : requestBodyRaw,
|
||||
);
|
||||
|
||||
// Response body derivations
|
||||
let responseContentType = $derived(
|
||||
capture ? getContentType(capture.resp_headers) : "",
|
||||
);
|
||||
let isResponseImage = $derived(isImageContentType(responseContentType));
|
||||
let isResponseText = $derived(isTextContentType(responseContentType));
|
||||
let isResponseJson = $derived(responseContentType.includes("json"));
|
||||
let isSSE = $derived(responseContentType.includes("text/event-stream"));
|
||||
|
||||
let responseBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.resp_body);
|
||||
});
|
||||
|
||||
let responseBodyPretty = $derived.by(() => {
|
||||
if (!isResponseJson) return responseBodyRaw;
|
||||
return formatJson(responseBodyRaw);
|
||||
});
|
||||
|
||||
let sseChat = $derived.by(() => {
|
||||
if (!isSSE || !responseBodyRaw)
|
||||
return { reasoning: "", content: "" } as SSEChat;
|
||||
return parseSSEChat(responseBodyRaw);
|
||||
});
|
||||
|
||||
let displayedResponseBody = $derived.by(() => {
|
||||
if (respBodyTab === "pretty") return responseBodyPretty;
|
||||
return responseBodyRaw;
|
||||
});
|
||||
</script>
|
||||
|
||||
<dialog
|
||||
bind:this={dialogEl}
|
||||
onclose={handleDialogClose}
|
||||
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
||||
>
|
||||
{#if capture}
|
||||
<div class="flex flex-col max-h-[90vh]">
|
||||
<div
|
||||
class="flex justify-between items-center p-4 border-b border-card-border"
|
||||
>
|
||||
<h2 class="text-xl font-bold pb-0">Capture #{capture.id + 1}{#if capture.req_path} <span class="text-base font-mono font-normal text-txtsecondary">{capture.req_path}</span>{/if}</h2>
|
||||
<button
|
||||
onclick={() => dialogEl?.close()}
|
||||
class="text-txtsecondary hover:text-txtmain text-2xl leading-none"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="overflow-y-auto flex-1 p-4 space-y-4">
|
||||
<!-- Request Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.req_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Request Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Body
|
||||
</summary>
|
||||
{#if requestBodyRaw}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isRequestJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "pretty"}
|
||||
onclick={() => (reqBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "raw"}
|
||||
onclick={() => (reqBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() =>
|
||||
copyToClipboard(displayedRequestBody, "req")}
|
||||
>
|
||||
{#if copiedReq}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedRequestBody}</pre>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono whitespace-pre-wrap break-all"
|
||||
>(empty)</pre
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
|
||||
<!-- Response Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.resp_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Response Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Body
|
||||
</summary>
|
||||
{#if isResponseImage && capture.resp_body}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 flex justify-center">
|
||||
<img
|
||||
src={getImageDataUrl(capture.resp_body, responseContentType)}
|
||||
alt="Response"
|
||||
class="max-w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSSE || isResponseText}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isSSE}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "chat"}
|
||||
onclick={() => (respBodyTab = "chat")}>Chat</button
|
||||
>
|
||||
{/if}
|
||||
{#if isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "pretty"}
|
||||
onclick={() => (respBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
{/if}
|
||||
{#if isSSE || isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "raw"}
|
||||
onclick={() => (respBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() => copyToClipboard(getCopyText(), "resp")}
|
||||
>
|
||||
{#if copiedResp}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
{#if respBodyTab === "chat"}
|
||||
<div class="p-3 text-sm space-y-3">
|
||||
{#if sseChat.reasoning}
|
||||
<div>
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Reasoning
|
||||
</div>
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all text-txtsecondary">{sseChat.reasoning}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if sseChat.content}
|
||||
<div>
|
||||
{#if sseChat.reasoning}
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Response
|
||||
</div>
|
||||
{/if}
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all">{sseChat.content}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if !sseChat.reasoning && !sseChat.content}
|
||||
<pre class="font-mono">(empty)</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else}
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedResponseBody || "(empty)"}</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if responseBodyRaw}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 text-sm text-txtsecondary italic">
|
||||
(binary data - {responseContentType || "unknown content type"})
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono">(empty)</pre>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
</div>
|
||||
|
||||
<div class="p-4 border-t border-card-border flex justify-end">
|
||||
<button onclick={() => dialogEl?.close()} class="btn"> Close </button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</dialog>
|
||||
|
||||
<style>
|
||||
.tab-btn {
|
||||
padding: 2px 10px;
|
||||
font-size: 0.75rem;
|
||||
border-radius: 4px;
|
||||
color: var(--color-txtsecondary);
|
||||
cursor: pointer;
|
||||
border: 1px solid transparent;
|
||||
background: transparent;
|
||||
transition: all 0.15s;
|
||||
}
|
||||
.tab-btn:hover {
|
||||
color: var(--color-txtmain);
|
||||
background: var(--color-secondary);
|
||||
}
|
||||
.tab-btn-active {
|
||||
color: var(--color-primary);
|
||||
background: color-mix(in srgb, var(--color-primary) 12%, transparent);
|
||||
border-color: color-mix(in srgb, var(--color-primary) 25%, transparent);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,24 @@
|
||||
<script lang="ts">
|
||||
import { connectionState } from "../stores/theme";
|
||||
import { versionInfo } from "../stores/api";
|
||||
|
||||
let eventStatusColor = $derived.by(() => {
|
||||
switch ($connectionState) {
|
||||
case "connected":
|
||||
return "bg-emerald-500";
|
||||
case "connecting":
|
||||
return "bg-amber-500";
|
||||
case "disconnected":
|
||||
default:
|
||||
return "bg-red-500";
|
||||
}
|
||||
});
|
||||
|
||||
let tooltipText = $derived(
|
||||
`Event Stream: ${$connectionState ?? "unknown"}\nAPI Version: ${$versionInfo?.version ?? "unknown"}\nCommit Hash: ${$versionInfo?.commit?.substring(0, 7) ?? "unknown"}\nBuild Date: ${$versionInfo?.build_date ?? "unknown"}`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="flex items-center" title={tooltipText}>
|
||||
<span class="inline-block w-3 h-3 rounded-full {eventStatusColor} mr-2"></span>
|
||||
</div>
|
||||
@@ -0,0 +1,120 @@
|
||||
<script lang="ts">
|
||||
import { link } from "svelte-spa-router";
|
||||
import { screenWidth, toggleTheme, isDarkMode, appTitle, isNarrow } from "../stores/theme";
|
||||
import { currentRoute } from "../stores/route";
|
||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||
|
||||
function handleTitleChange(newTitle: string): void {
|
||||
const sanitized = newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap";
|
||||
appTitle.set(sanitized);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
target.blur();
|
||||
}
|
||||
}
|
||||
|
||||
function handleBlur(e: FocusEvent): void {
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
}
|
||||
|
||||
function isActive(path: string, current: string): boolean {
|
||||
return path === "/" ? current === "/" : current.startsWith(path);
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<header
|
||||
class="flex items-center justify-between bg-surface border-b border-border px-4 {$isNarrow
|
||||
? 'py-1 h-[60px]'
|
||||
: 'p-2 h-[75px]'}"
|
||||
>
|
||||
{#if $screenWidth !== "xs" && $screenWidth !== "sm"}
|
||||
<h1
|
||||
contenteditable="true"
|
||||
class="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||
onblur={handleBlur}
|
||||
onkeydown={handleKeyDown}
|
||||
>
|
||||
{$appTitle}
|
||||
</h1>
|
||||
{/if}
|
||||
|
||||
<menu class="flex items-center gap-4 overflow-x-auto">
|
||||
<a
|
||||
href="/"
|
||||
use:link
|
||||
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
||||
>
|
||||
Playground
|
||||
</a>
|
||||
<a
|
||||
href="/models"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/models", $currentRoute)}
|
||||
>
|
||||
Models
|
||||
</a>
|
||||
<a
|
||||
href="/activity"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/activity", $currentRoute)}
|
||||
>
|
||||
Activity
|
||||
</a>
|
||||
<a
|
||||
href="/logs"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/logs", $currentRoute)}
|
||||
>
|
||||
Logs
|
||||
</a>
|
||||
<button onclick={toggleTheme} title="Toggle theme">
|
||||
{#if $isDarkMode}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M9.528 1.718a.75.75 0 0 1 .162.819A8.97 8.97 0 0 0 9 6a9 9 0 0 0 9 9 8.97 8.97 0 0 0 3.463-.69.75.75 0 0 1 .981.98 10.503 10.503 0 0 1-9.694 6.46c-5.799 0-10.5-4.7-10.5-10.5 0-4.368 2.667-8.112 6.46-9.694a.75.75 0 0 1 .818.162Z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
d="M12 2.25a.75.75 0 0 1 .75.75v2.25a.75.75 0 0 1-1.5 0V3a.75.75 0 0 1 .75-.75ZM7.5 12a4.5 4.5 0 1 1 9 0 4.5 4.5 0 0 1-9 0ZM18.894 6.166a.75.75 0 0 0-1.06-1.06l-1.591 1.59a.75.75 0 1 0 1.06 1.061l1.591-1.59ZM21.75 12a.75.75 0 0 1-.75.75h-2.25a.75.75 0 0 1 0-1.5H21a.75.75 0 0 1 .75.75ZM17.834 18.894a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 1 0-1.061 1.06l1.591 1.591ZM12 18a.75.75 0 0 1 .75.75V21a.75.75 0 0 1-1.5 0v-2.25A.75.75 0 0 1 12 18ZM7.758 17.303a.75.75 0 0 0-1.061-1.06l-1.591 1.59a.75.75 0 0 0 1.06 1.061l1.591-1.59ZM6 12a.75.75 0 0 1-.75.75H3a.75.75 0 0 1 0-1.5h2.25A.75.75 0 0 1 6 12ZM6.697 7.757a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 0 0-1.061 1.06l1.59 1.591Z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<ConnectionStatus />
|
||||
</menu>
|
||||
</header>
|
||||
|
||||
<style>
|
||||
.activity-link {
|
||||
background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7, #8b5cf6, #6366f1);
|
||||
background-size: 200% 100%;
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
animation: gradient-shift 2s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes gradient-shift {
|
||||
0% {
|
||||
background-position: 0% 50%;
|
||||
}
|
||||
100% {
|
||||
background-position: 200% 50%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,132 @@
|
||||
<script lang="ts">
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
}
|
||||
|
||||
let { id, title, logData }: Props = $props();
|
||||
|
||||
let filterRegex = $state("");
|
||||
|
||||
// Create persistent stores for this panel (id is intentionally captured at init time)
|
||||
// svelte-ignore state_referenced_locally
|
||||
const fontSizeStore = persistentStore<"xxs" | "xs" | "small" | "normal">(`logPanel-${id}-fontSize`, "normal");
|
||||
// svelte-ignore state_referenced_locally
|
||||
const wrapTextStore = persistentStore<boolean>(`logPanel-${id}-wrapText`, false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
const showFilterStore = persistentStore<boolean>(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
let textWrapClass = $derived($wrapTextStore ? "whitespace-pre-wrap" : "whitespace-pre");
|
||||
|
||||
function toggleFontSize(): void {
|
||||
fontSizeStore.update((prev) => {
|
||||
switch (prev) {
|
||||
case "xxs": return "xs";
|
||||
case "xs": return "small";
|
||||
case "small": return "normal";
|
||||
case "normal": return "xxs";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function toggleWrapText(): void {
|
||||
wrapTextStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function toggleFilter(): void {
|
||||
if ($showFilterStore) {
|
||||
showFilterStore.set(false);
|
||||
filterRegex = "";
|
||||
} else {
|
||||
showFilterStore.set(true);
|
||||
}
|
||||
}
|
||||
|
||||
let fontSizeClass = $derived.by(() => {
|
||||
switch ($fontSizeStore) {
|
||||
case "xxs": return "text-[0.5rem]";
|
||||
case "xs": return "text-[0.75rem]";
|
||||
case "small": return "text-[0.875rem]";
|
||||
case "normal": return "text-base";
|
||||
}
|
||||
});
|
||||
|
||||
let filteredLogs = $derived.by(() => {
|
||||
if (!filterRegex) return logData;
|
||||
try {
|
||||
const regex = new RegExp(filterRegex, "i");
|
||||
return logData.split("\n").filter((line) => regex.test(line)).join("\n");
|
||||
} catch {
|
||||
return logData;
|
||||
}
|
||||
});
|
||||
|
||||
let preElement: HTMLPreElement;
|
||||
|
||||
// Auto scroll to bottom when logs change
|
||||
$effect(() => {
|
||||
if (preElement && filteredLogs) {
|
||||
preElement.scrollTop = preElement.scrollHeight;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full w-full p-1">
|
||||
<div class="p-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div class="flex gap-2 items-center">
|
||||
<button class="btn border-0" onclick={toggleFontSize} title="Change font size">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M10.5 3.75a6 6 0 0 0-5.98 6.496A5.25 5.25 0 0 0 6.75 20.25H18a4.5 4.5 0 0 0 2.206-8.423 3.75 3.75 0 0 0-4.133-4.303A6.001 6.001 0 0 0 10.5 3.75Zm2.25 6a.75.75 0 0 0-1.5 0v4.94l-1.72-1.72a.75.75 0 0 0-1.06 1.06l3 3a.75.75 0 0 0 1.06 0l3-3a.75.75 0 1 0-1.06-1.06l-1.72 1.72V9.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleWrapText} title="Toggle text wrap">
|
||||
{#if $wrapTextStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h10.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleFilter} title="Toggle filter">
|
||||
{#if $showFilterStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M10.5 3.75a6.75 6.75 0 1 0 0 13.5 6.75 6.75 0 0 0 0-13.5ZM2.25 10.5a8.25 8.25 0 1 1 14.59 5.28l4.69 4.69a.75.75 0 1 1-1.06 1.06l-4.69-4.69A8.25 8.25 0 0 1 2.25 10.5Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if $showFilterStore}
|
||||
<div class="mt-2 flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
class="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||
placeholder="Filter logs (regex)..."
|
||||
bind:value={filterRegex}
|
||||
/>
|
||||
<button class="pl-2" onclick={() => (filterRegex = "")} aria-label="Clear filter">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm-1.72 6.97a.75.75 0 1 0-1.06 1.06L10.94 12l-1.72 1.72a.75.75 0 1 0 1.06 1.06L12 13.06l1.72 1.72a.75.75 0 1 0 1.06-1.06L13.06 12l1.72-1.72a.75.75 0 1 0-1.06-1.06L12 10.94l-1.72-1.72Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre bind:this={preElement} class="{textWrapClass} {fontSizeClass} h-full overflow-auto p-4">{filteredLogs}</pre>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,208 @@
|
||||
<script lang="ts">
|
||||
import { models, loadModel, unloadAllModels, unloadSingleModel } from "../stores/api";
|
||||
import { isNarrow } from "../stores/theme";
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import type { Model } from "../lib/types";
|
||||
|
||||
let isUnloading = $state(false);
|
||||
let menuOpen = $state(false);
|
||||
|
||||
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||
|
||||
let filteredModels = $derived.by(() => {
|
||||
const filtered = $models.filter((model) => $showUnlistedStore || !model.unlisted);
|
||||
const peerModels = filtered.filter((m) => m.peerID);
|
||||
|
||||
// Group peer models by peerID
|
||||
const grouped = peerModels.reduce(
|
||||
(acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
if (!acc[peerId]) acc[peerId] = [];
|
||||
acc[peerId].push(model);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, Model[]>
|
||||
);
|
||||
|
||||
return {
|
||||
regularModels: filtered.filter((m) => !m.peerID),
|
||||
peerModelsByPeerId: grouped,
|
||||
};
|
||||
});
|
||||
|
||||
async function handleUnloadAllModels(): Promise<void> {
|
||||
isUnloading = true;
|
||||
try {
|
||||
await unloadAllModels();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
setTimeout(() => (isUnloading = false), 1000);
|
||||
}
|
||||
}
|
||||
|
||||
function toggleIdorName(): void {
|
||||
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||
}
|
||||
|
||||
function toggleShowUnlisted(): void {
|
||||
showUnlistedStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function getModelDisplay(model: Model): string {
|
||||
return $showIdorNameStore === "id" ? model.id : (model.name || model.id);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="card h-full flex flex-col">
|
||||
<div class="shrink-0">
|
||||
<div class="flex justify-between items-baseline">
|
||||
<h2 class={$isNarrow ? "text-xl" : ""}>Models</h2>
|
||||
{#if $isNarrow}
|
||||
<div class="relative">
|
||||
<button class="btn text-base flex items-center gap-2 py-1" onclick={() => (menuOpen = !menuOpen)} aria-label="Toggle menu">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if menuOpen}
|
||||
<div class="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleIdorName(); menuOpen = false; }}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "Show Name" : "Show ID"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleShowUnlisted(); menuOpen = false; }}
|
||||
>
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
{$showUnlistedStore ? "Hide Unlisted" : "Show Unlisted"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { handleUnloadAllModels(); menuOpen = false; }}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{#if !$isNarrow}
|
||||
<div class="flex justify-between">
|
||||
<div class="flex gap-2">
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleIdorName} style="line-height: 1.2">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleShowUnlisted} style="line-height: 1.2">
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
unlisted
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn text-base flex items-center gap-2" onclick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th>{$showIdorNameStore === "id" ? "Model ID" : "Name"}</th>
|
||||
<th></th>
|
||||
<th>State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each filteredModels.regularModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class={model.unlisted ? "text-txtsecondary" : ""}>
|
||||
<a href="/upstream/{model.id}/" class="font-semibold" target="_blank">
|
||||
{getModelDisplay(model)}
|
||||
</a>
|
||||
{#if model.description}
|
||||
<p class={model.unlisted ? "text-opacity-70" : ""}><em>{model.description}</em></p>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-12">
|
||||
{#if model.state === "stopped"}
|
||||
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
||||
{:else}
|
||||
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-20">
|
||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{#if Object.keys(filteredModels.peerModelsByPeerId).length > 0}
|
||||
<h3 class="mt-8 mb-2">Peer Models</h3>
|
||||
{#each Object.entries(filteredModels.peerModelsByPeerId).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||
<div class="mb-4">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th class="font-semibold">{peerId}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each peerModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class="pl-8 {model.unlisted ? 'text-txtsecondary' : ''}">
|
||||
<span>{model.id}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,152 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet } from "svelte";
|
||||
import { onMount } from "svelte";
|
||||
|
||||
interface Props {
|
||||
direction: "horizontal" | "vertical";
|
||||
storageKey: string;
|
||||
leftPanel: Snippet;
|
||||
rightPanel: Snippet;
|
||||
defaultSize?: number;
|
||||
minSize?: number;
|
||||
}
|
||||
|
||||
let { direction, storageKey, leftPanel, rightPanel, defaultSize = 50, minSize = 5 }: Props = $props();
|
||||
|
||||
let containerRef: HTMLDivElement;
|
||||
let isDragging = $state(false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
let leftSize = $state(defaultSize);
|
||||
|
||||
// Load saved size from localStorage
|
||||
onMount(() => {
|
||||
const saved = localStorage.getItem(`panel-size-${storageKey}`);
|
||||
if (saved) {
|
||||
const parsed = parseFloat(saved);
|
||||
if (!isNaN(parsed) && parsed >= minSize && parsed <= 100 - minSize) {
|
||||
leftSize = parsed;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function saveSize(): void {
|
||||
localStorage.setItem(`panel-size-${storageKey}`, String(leftSize));
|
||||
}
|
||||
|
||||
function handleMouseDown(e: MouseEvent): void {
|
||||
e.preventDefault();
|
||||
isDragging = true;
|
||||
document.addEventListener("mousemove", handleMouseMove);
|
||||
document.addEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchStart(_e: TouchEvent): void {
|
||||
isDragging = true;
|
||||
document.addEventListener("touchmove", handleTouchMove);
|
||||
document.addEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleMouseMove(e: MouseEvent): void {
|
||||
if (!isDragging || !containerRef) return;
|
||||
updateSize(e.clientX, e.clientY);
|
||||
}
|
||||
|
||||
function handleTouchMove(e: TouchEvent): void {
|
||||
if (!isDragging || !containerRef || e.touches.length === 0) return;
|
||||
updateSize(e.touches[0].clientX, e.touches[0].clientY);
|
||||
}
|
||||
|
||||
function updateSize(clientX: number, clientY: number): void {
|
||||
const rect = containerRef.getBoundingClientRect();
|
||||
|
||||
let newSize: number;
|
||||
if (direction === "horizontal") {
|
||||
newSize = ((clientX - rect.left) / rect.width) * 100;
|
||||
} else {
|
||||
newSize = ((clientY - rect.top) / rect.height) * 100;
|
||||
}
|
||||
|
||||
// Clamp size
|
||||
newSize = Math.max(minSize, Math.min(100 - minSize, newSize));
|
||||
leftSize = newSize;
|
||||
}
|
||||
|
||||
function handleMouseUp(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("mousemove", handleMouseMove);
|
||||
document.removeEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchEnd(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("touchmove", handleTouchMove);
|
||||
document.removeEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
const step = 2; // 2% increment for keyboard navigation
|
||||
const key = e.key;
|
||||
|
||||
if (direction === "horizontal" && (key === "ArrowLeft" || key === "ArrowRight")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowLeft" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
} else if (direction === "vertical" && (key === "ArrowUp" || key === "ArrowDown")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowUp" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
}
|
||||
}
|
||||
|
||||
let containerClass = $derived(direction === "horizontal" ? "flex-row" : "flex-col");
|
||||
|
||||
let handleClass = $derived(
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full cursor-col-resize"
|
||||
: "w-full h-2 cursor-row-resize"
|
||||
);
|
||||
|
||||
let leftStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
|
||||
let rightStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${100 - leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${100 - leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="flex {containerClass} h-full w-full gap-2">
|
||||
<div style={leftStyle} class="overflow-hidden">
|
||||
{@render leftPanel()}
|
||||
</div>
|
||||
|
||||
<!-- svelte-ignore a11y_no_noninteractive_tabindex -->
|
||||
<!-- svelte-ignore a11y_no_noninteractive_element_interactions -->
|
||||
<div
|
||||
role="separator"
|
||||
tabindex="0"
|
||||
class="{handleClass} bg-primary hover:bg-success transition-colors rounded flex-shrink-0"
|
||||
onmousedown={handleMouseDown}
|
||||
ontouchstart={handleTouchStart}
|
||||
onkeydown={handleKeyDown}
|
||||
aria-label="Resize panels"
|
||||
aria-orientation={direction}
|
||||
aria-valuenow={Math.round(leftSize)}
|
||||
aria-valuemin={minSize}
|
||||
aria-valuemax={100 - minSize}
|
||||
></div>
|
||||
|
||||
<div style={rightStyle} class="overflow-hidden">
|
||||
{@render rightPanel()}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,167 @@
|
||||
<script lang="ts">
|
||||
import { inFlightRequests, metrics } from "../stores/api";
|
||||
import TokenHistogram from "./TokenHistogram.svelte";
|
||||
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
let stats = $derived.by(() => {
|
||||
const totalRequests = $metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return {
|
||||
totalRequests: 0,
|
||||
totalInputTokens: 0,
|
||||
totalOutputTokens: 0,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
|
||||
// Calculate token statistics using output_tokens and duration_ms
|
||||
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
||||
if (validMetrics.length === 0) {
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
// Calculate tokens/second for each valid metric
|
||||
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
||||
|
||||
// Sort for percentile calculation
|
||||
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
||||
|
||||
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
||||
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
||||
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
||||
|
||||
// Create histogram data
|
||||
const min = Math.min(...tokensPerSecond);
|
||||
const max = Math.max(...tokensPerSecond);
|
||||
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
|
||||
const binSize = (max - min) / binCount;
|
||||
|
||||
const bins = Array(binCount).fill(0);
|
||||
tokensPerSecond.forEach((value) => {
|
||||
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
||||
bins[binIndex]++;
|
||||
});
|
||||
|
||||
const histogramData: HistogramData = {
|
||||
bins,
|
||||
min,
|
||||
max,
|
||||
binSize,
|
||||
p99,
|
||||
p95,
|
||||
p50,
|
||||
};
|
||||
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: {
|
||||
p99: p99.toFixed(2),
|
||||
p95: p95.toFixed(2),
|
||||
p50: p50.toFixed(2),
|
||||
},
|
||||
histogramData,
|
||||
};
|
||||
});
|
||||
|
||||
const nf = new Intl.NumberFormat();
|
||||
</script>
|
||||
|
||||
<div class="card">
|
||||
<div class="rounded-lg overflow-hidden border border-card-border-inner">
|
||||
<table class="min-w-full divide-y divide-card-border-inner">
|
||||
<thead class="bg-secondary">
|
||||
<tr>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Processed
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Generated
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Token Stats (tokens/sec)
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
||||
<tbody class="bg-surface divide-y divide-card-border-inner">
|
||||
<tr class="hover:bg-secondary">
|
||||
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="space-y-3">
|
||||
<div class="grid grid-cols-3 gap-2 items-center">
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p50}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p95}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p99}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{#if stats.histogramData}
|
||||
<TokenHistogram data={stats.histogramData} />
|
||||
{/if}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,129 @@
|
||||
<script lang="ts">
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
data: HistogramData;
|
||||
}
|
||||
|
||||
let { data }: Props = $props();
|
||||
|
||||
const height = 120;
|
||||
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
||||
const viewBoxWidth = 600;
|
||||
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||
const chartHeight = height - padding.top - padding.bottom;
|
||||
|
||||
let maxCount = $derived(Math.max(...data.bins));
|
||||
let barWidth = $derived(chartWidth / data.bins.length);
|
||||
let range = $derived(data.max - data.min);
|
||||
|
||||
function getXPosition(value: number): number {
|
||||
return padding.left + ((value - data.min) / range) * chartWidth;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="mt-2 w-full">
|
||||
<svg viewBox="0 0 {viewBoxWidth} {height}" class="w-full h-auto" preserveAspectRatio="xMidYMid meet">
|
||||
<!-- Y-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={padding.top}
|
||||
x2={padding.left}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- X-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={height - padding.bottom}
|
||||
x2={viewBoxWidth - padding.right}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- Histogram bars -->
|
||||
{#each data.bins as count, i}
|
||||
{@const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0}
|
||||
{@const x = padding.left + i * barWidth}
|
||||
{@const y = height - padding.bottom - barHeight}
|
||||
{@const binStart = data.min + i * data.binSize}
|
||||
{@const binEnd = binStart + data.binSize}
|
||||
<g>
|
||||
<rect
|
||||
{x}
|
||||
{y}
|
||||
width={Math.max(barWidth - 1, 1)}
|
||||
height={barHeight}
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
||||
/>
|
||||
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
||||
</g>
|
||||
{/each}
|
||||
|
||||
<!-- Percentile lines -->
|
||||
<line
|
||||
x1={getXPosition(data.p50)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p50)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-gray-600 dark:text-gray-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p95)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p95)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-orange-500 dark:text-orange-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p99)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p99)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-green-500 dark:text-green-400"
|
||||
/>
|
||||
|
||||
<!-- X-axis labels -->
|
||||
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start">
|
||||
{data.min.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end">
|
||||
{data.max.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<!-- X-axis label -->
|
||||
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
|
||||
Tokens/Second Distribution
|
||||
</text>
|
||||
</svg>
|
||||
</div>
|
||||
@@ -0,0 +1,20 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
content: string;
|
||||
}
|
||||
|
||||
let { content }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative group inline-block">
|
||||
<span class="cursor-help">ⓘ</span>
|
||||
<div
|
||||
class="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||
opacity-0 group-hover:opacity-100 transition-opacity
|
||||
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||
>
|
||||
{content}
|
||||
<div class="absolute bottom-full left-1/2 transform -translate-x-1/2 border-4 border-transparent border-b-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,256 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { transcribeAudio } from "../../lib/audioApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-audio-model", "");
|
||||
|
||||
let selectedFile = $state<File | null>(null);
|
||||
let isTranscribing = $state(false);
|
||||
let transcriptionResult = $state<string | null>(null);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let isDragging = $state(false);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let copied = $state(false);
|
||||
|
||||
const ACCEPTED_FORMATS = ['.mp3', '.wav'];
|
||||
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
let canTranscribe = $derived(selectedFile !== null && $selectedModelStore !== "" && !isTranscribing);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.audioTranscribing.set(isTranscribing);
|
||||
});
|
||||
|
||||
function validateFile(file: File): { valid: boolean; error?: string } {
|
||||
const ext = '.' + file.name.split('.').pop()?.toLowerCase();
|
||||
|
||||
if (!ACCEPTED_FORMATS.includes(ext)) {
|
||||
return { valid: false, error: 'Invalid file type. Accepted: MP3, WAV' };
|
||||
}
|
||||
|
||||
if (file.size > MAX_FILE_SIZE) {
|
||||
return { valid: false, error: 'File too large. Maximum: 25MB' };
|
||||
}
|
||||
|
||||
return { valid: true };
|
||||
}
|
||||
|
||||
function handleFileSelect(event: Event) {
|
||||
const target = event.target as HTMLInputElement;
|
||||
const file = target.files?.[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = true;
|
||||
}
|
||||
|
||||
function handleDragLeave() {
|
||||
isDragging = false;
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = false;
|
||||
|
||||
const file = event.dataTransfer?.files[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function transcribe() {
|
||||
if (!selectedFile || !$selectedModelStore || isTranscribing) return;
|
||||
|
||||
isTranscribing = true;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const response = await transcribeAudio(
|
||||
$selectedModelStore,
|
||||
selectedFile,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
transcriptionResult = response.text;
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isTranscribing = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelTranscription() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearAll() {
|
||||
selectedFile = null;
|
||||
transcriptionResult = null;
|
||||
error = null;
|
||||
if (fileInput) {
|
||||
fileInput.value = '';
|
||||
}
|
||||
}
|
||||
|
||||
function copyToClipboard() {
|
||||
if (transcriptionResult) {
|
||||
navigator.clipboard.writeText(transcriptionResult);
|
||||
copied = true;
|
||||
setTimeout(() => {
|
||||
copied = false;
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
function formatFileSize(bytes: number): string {
|
||||
if (bytes < 1024) return bytes + ' B';
|
||||
if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB';
|
||||
return (bytes / (1024 * 1024)).toFixed(1) + ' MB';
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to transcribe audio.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- File upload / Result display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isTranscribing}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Transcribing audio...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if transcriptionResult}
|
||||
<div class="w-full h-full flex flex-col p-4">
|
||||
<div class="flex justify-between items-center mb-2">
|
||||
<h3 class="font-medium">Transcription Result</h3>
|
||||
<button
|
||||
class="btn btn-sm"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? 'Copied!' : 'Copy to clipboard'}
|
||||
>
|
||||
{#if copied}
|
||||
<svg class="w-5 h-5 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex-1 overflow-auto p-3 rounded border border-gray-200 dark:border-white/10 bg-background whitespace-pre-wrap">
|
||||
{transcriptionResult}
|
||||
</div>
|
||||
</div>
|
||||
{:else if selectedFile}
|
||||
<div class="text-center text-txtsecondary p-4">
|
||||
<p class="font-medium mb-2">File Selected</p>
|
||||
<p class="text-sm">{selectedFile.name}</p>
|
||||
<p class="text-xs mt-1">{formatFileSize(selectedFile.size)}</p>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
role="region"
|
||||
aria-label="Audio file drop zone"
|
||||
class="w-full h-full flex items-center justify-center text-center text-txtsecondary p-8 {isDragging ? 'bg-primary/10' : ''}"
|
||||
ondragover={handleDragOver}
|
||||
ondragleave={handleDragLeave}
|
||||
ondrop={handleDrop}
|
||||
>
|
||||
<div>
|
||||
<p class="mb-2">Drag and drop an audio file here</p>
|
||||
<p class="text-sm">or use the Browse button below</p>
|
||||
<p class="text-xs mt-4">Accepted formats: MP3, WAV (max 25MB)</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- File input and transcribe button -->
|
||||
<div class="shrink-0 flex gap-2">
|
||||
<input
|
||||
type="file"
|
||||
accept=".mp3,.wav"
|
||||
class="hidden"
|
||||
onchange={handleFileSelect}
|
||||
bind:this={fileInput}
|
||||
/>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isTranscribing}
|
||||
>
|
||||
Browse Files
|
||||
</button>
|
||||
<div class="flex-1"></div>
|
||||
{#if isTranscribing}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelTranscription}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={transcribe}
|
||||
disabled={!canTranscribe}
|
||||
>
|
||||
Transcribe
|
||||
</button>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={clearAll}
|
||||
disabled={!selectedFile && !transcriptionResult && !error}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,466 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { streamChatCompletion } from "../../lib/chatApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import type { ChatMessage, ContentPart } from "../../lib/types";
|
||||
import ChatMessageComponent from "./ChatMessage.svelte";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-selected-model", "");
|
||||
const systemPromptStore = persistentStore<string>("playground-system-prompt", "");
|
||||
const temperatureStore = persistentStore<number>("playground-temperature", 0.7);
|
||||
|
||||
function loadMessages(): ChatMessage[] {
|
||||
try {
|
||||
const saved = localStorage.getItem("playground-messages");
|
||||
return saved ? JSON.parse(saved) : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
let messages = $state<ChatMessage[]>(loadMessages());
|
||||
let userInput = $state("");
|
||||
let isStreaming = $state(false);
|
||||
let isReasoning = $state(false);
|
||||
let reasoningStartTime = $state<number>(0);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let messagesContainer: HTMLDivElement | undefined = $state();
|
||||
let showSettings = $state(false);
|
||||
let attachedImages = $state<string[]>([]);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let imageError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
let userScrolledUp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.chatStreaming.set(isStreaming);
|
||||
});
|
||||
|
||||
function handleMessagesScroll() {
|
||||
if (!messagesContainer) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = messagesContainer;
|
||||
// Consider "at bottom" if within 40px of the bottom
|
||||
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||
}
|
||||
|
||||
// Auto-scroll when messages change — skip if user scrolled up
|
||||
$effect(() => {
|
||||
if (messages.length > 0 && messagesContainer && !userScrolledUp) {
|
||||
messagesContainer.scrollTo({
|
||||
top: messagesContainer.scrollHeight,
|
||||
behavior: isStreaming ? "instant" : "smooth",
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Persist messages to localStorage (throttled to once per 2s)
|
||||
let lastSaveTime = 0;
|
||||
$effect(() => {
|
||||
const json = JSON.stringify(messages);
|
||||
const elapsed = Date.now() - lastSaveTime;
|
||||
const save = () => {
|
||||
try { localStorage.setItem("playground-messages", json); } catch {}
|
||||
lastSaveTime = Date.now();
|
||||
};
|
||||
if (elapsed >= 2000) {
|
||||
save();
|
||||
return;
|
||||
}
|
||||
const timer = setTimeout(save, 2000 - elapsed);
|
||||
return () => clearTimeout(timer);
|
||||
});
|
||||
|
||||
async function sendMessage() {
|
||||
const trimmedInput = userInput.trim();
|
||||
if ((!trimmedInput && attachedImages.length === 0) || !$selectedModelStore || isStreaming) return;
|
||||
|
||||
userScrolledUp = false;
|
||||
|
||||
// Build message content (multimodal if images attached)
|
||||
let content: string | ContentPart[];
|
||||
if (attachedImages.length > 0) {
|
||||
const parts: ContentPart[] = [];
|
||||
if (trimmedInput) {
|
||||
parts.push({ type: "text", text: trimmedInput });
|
||||
}
|
||||
for (const url of attachedImages) {
|
||||
parts.push({ type: "image_url", image_url: { url } });
|
||||
}
|
||||
content = parts;
|
||||
} else {
|
||||
content = trimmedInput;
|
||||
}
|
||||
|
||||
// Add user message
|
||||
messages = [...messages, { role: "user", content }];
|
||||
userInput = "";
|
||||
attachedImages = [];
|
||||
imageError = null;
|
||||
|
||||
// Generate response from the new user message
|
||||
await regenerateFromIndex(messages.length - 1);
|
||||
}
|
||||
|
||||
function cancelStreaming() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function newChat() {
|
||||
if (isStreaming) {
|
||||
cancelStreaming();
|
||||
}
|
||||
messages = [];
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
}
|
||||
|
||||
async function regenerateFromIndex(idx: number) {
|
||||
// Remove all messages after the edited user message
|
||||
messages = messages.slice(0, idx + 1);
|
||||
|
||||
// Add empty assistant message for the new response
|
||||
messages = [...messages, { role: "assistant", content: "" }];
|
||||
|
||||
isStreaming = true;
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
// Build messages array with optional system prompt
|
||||
const apiMessages: ChatMessage[] = [];
|
||||
if ($systemPromptStore.trim()) {
|
||||
apiMessages.push({ role: "system", content: $systemPromptStore.trim() });
|
||||
}
|
||||
apiMessages.push(...messages.slice(0, -1)); // Add all messages except the empty assistant one
|
||||
|
||||
const stream = streamChatCompletion(
|
||||
$selectedModelStore,
|
||||
apiMessages,
|
||||
abortController.signal,
|
||||
{ temperature: $temperatureStore }
|
||||
);
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.done) break;
|
||||
|
||||
// Handle reasoning content
|
||||
if (chunk.reasoning_content) {
|
||||
// Start timing on first reasoning content
|
||||
if (!isReasoning) {
|
||||
isReasoning = true;
|
||||
reasoningStartTime = Date.now();
|
||||
}
|
||||
|
||||
// Update the last message with reasoning content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoning_content: (msg.reasoning_content || "") + chunk.reasoning_content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Handle regular content - end reasoning phase when we get content
|
||||
if (chunk.content) {
|
||||
if (isReasoning) {
|
||||
// Calculate reasoning time
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
isReasoning = false;
|
||||
|
||||
// Update message with reasoning time
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Update the last message (assistant) with new content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + chunk.content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
// User cancelled, keep partial response
|
||||
// If we were still reasoning, record the time
|
||||
if (isReasoning && reasoningStartTime > 0) {
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Show error in the assistant message
|
||||
const errorMessage = error instanceof Error ? error.message : "An error occurred";
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + `\n\n**Error:** ${errorMessage}` }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
isStreaming = false;
|
||||
isReasoning = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
async function editMessage(idx: number, newContent: string) {
|
||||
if (isStreaming || !$selectedModelStore) return;
|
||||
|
||||
// Update the user message at the specified index
|
||||
messages = messages.map((msg, i) =>
|
||||
i === idx ? { ...msg, content: newContent } : msg
|
||||
);
|
||||
|
||||
// Trigger a new chat request with the updated messages
|
||||
await regenerateFromIndex(idx);
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
sendMessage();
|
||||
}
|
||||
}
|
||||
|
||||
const ACCEPTED_IMAGE_FORMATS = ["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||
const MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB
|
||||
const MAX_IMAGES_PER_MESSAGE = 5;
|
||||
|
||||
function validateImageFile(file: File): string | null {
|
||||
if (!ACCEPTED_IMAGE_FORMATS.includes(file.type)) {
|
||||
return `Invalid file type: ${file.type}. Accepted formats: JPG, PNG, GIF, WEBP`;
|
||||
}
|
||||
if (file.size > MAX_IMAGE_SIZE) {
|
||||
return `File too large: ${(file.size / 1024 / 1024).toFixed(1)}MB. Maximum size: 20MB`;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function fileToDataUrl(file: File): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.onerror = () => reject(new Error("Failed to read file"));
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function processImageFiles(files: File[]): Promise<void> {
|
||||
imageError = null;
|
||||
|
||||
if (attachedImages.length + files.length > MAX_IMAGES_PER_MESSAGE) {
|
||||
imageError = `Maximum ${MAX_IMAGES_PER_MESSAGE} images per message`;
|
||||
return;
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
const error = validateImageFile(file);
|
||||
if (error) {
|
||||
imageError = error;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const dataUrls = await Promise.all(files.map(fileToDataUrl));
|
||||
attachedImages = [...attachedImages, ...dataUrls];
|
||||
} catch (error) {
|
||||
imageError = error instanceof Error ? error.message : "Failed to process images";
|
||||
}
|
||||
}
|
||||
|
||||
function handleImageSelect(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
if (input.files && input.files.length > 0) {
|
||||
processImageFiles(Array.from(input.files));
|
||||
}
|
||||
// Reset the input so the same file can be selected again
|
||||
input.value = "";
|
||||
}
|
||||
|
||||
function removeImage(idx: number) {
|
||||
attachedImages = attachedImages.filter((_, i) => i !== idx);
|
||||
imageError = null;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and controls -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a model..." disabled={isStreaming} />
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => (showSettings = !showSettings)}
|
||||
title="Settings"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M8.34 1.804A1 1 0 0 1 9.32 1h1.36a1 1 0 0 1 .98.804l.295 1.473c.497.144.971.342 1.416.587l1.25-.834a1 1 0 0 1 1.262.125l.962.962a1 1 0 0 1 .125 1.262l-.834 1.25c.245.445.443.919.587 1.416l1.473.295a1 1 0 0 1 .804.98v1.36a1 1 0 0 1-.804.98l-1.473.295a6.95 6.95 0 0 1-.587 1.416l.834 1.25a1 1 0 0 1-.125 1.262l-.962.962a1 1 0 0 1-1.262.125l-1.25-.834a6.953 6.953 0 0 1-1.416.587l-.295 1.473a1 1 0 0 1-.98.804H9.32a1 1 0 0 1-.98-.804l-.295-1.473a6.957 6.957 0 0 1-1.416-.587l-1.25.834a1 1 0 0 1-1.262-.125l-.962-.962a1 1 0 0 1-.125-1.262l.834-1.25a6.957 6.957 0 0 1-.587-1.416l-1.473-.295A1 1 0 0 1 1 10.68V9.32a1 1 0 0 1 .804-.98l1.473-.295c.144-.497.342-.971.587-1.416l-.834-1.25a1 1 0 0 1 .125-1.262l.962-.962A1 1 0 0 1 5.38 3.03l1.25.834a6.957 6.957 0 0 1 1.416-.587l.294-1.473ZM13 10a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn" onclick={newChat} disabled={messages.length === 0 && !isStreaming}>
|
||||
New Chat
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Settings panel -->
|
||||
{#if showSettings}
|
||||
<div class="shrink-0 mb-4 p-4 bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
<div class="mb-4">
|
||||
<label class="block text-sm font-medium mb-1" for="system-prompt">System Prompt</label>
|
||||
<textarea
|
||||
id="system-prompt"
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder="You are a helpful assistant..."
|
||||
rows="3"
|
||||
bind:value={$systemPromptStore}
|
||||
disabled={isStreaming}
|
||||
></textarea>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm font-medium mb-1" for="temperature">
|
||||
Temperature: {$temperatureStore.toFixed(2)}
|
||||
</label>
|
||||
<input
|
||||
id="temperature"
|
||||
type="range"
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.05"
|
||||
class="w-full"
|
||||
bind:value={$temperatureStore}
|
||||
disabled={isStreaming}
|
||||
/>
|
||||
<div class="flex justify-between text-xs text-txtsecondary mt-1">
|
||||
<span>Precise (0)</span>
|
||||
<span>Creative (2)</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to start chatting.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Messages area -->
|
||||
<div
|
||||
class="flex-1 overflow-y-auto mb-4 px-2"
|
||||
bind:this={messagesContainer}
|
||||
onscroll={handleMessagesScroll}
|
||||
>
|
||||
{#if messages.length === 0}
|
||||
<div class="h-full flex items-center justify-center text-txtsecondary">
|
||||
<p>Start a conversation by typing a message below.</p>
|
||||
</div>
|
||||
{:else}
|
||||
{#each messages as message, idx (idx)}
|
||||
<ChatMessageComponent
|
||||
role={message.role}
|
||||
content={message.content}
|
||||
reasoning_content={message.reasoning_content}
|
||||
reasoningTimeMs={message.reasoningTimeMs}
|
||||
isStreaming={isStreaming && idx === messages.length - 1 && message.role === "assistant"}
|
||||
isReasoning={isReasoning && idx === messages.length - 1 && message.role === "assistant"}
|
||||
onEdit={message.role === "user" ? (newContent) => editMessage(idx, newContent) : undefined}
|
||||
onRegenerate={message.role === "assistant" && idx > 0 && messages[idx - 1].role === "user"
|
||||
? () => regenerateFromIndex(idx - 1)
|
||||
: undefined}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Input area -->
|
||||
<div class="shrink-0">
|
||||
<!-- Image preview strip -->
|
||||
{#if attachedImages.length > 0}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each attachedImages as imageUrl, idx (idx)}
|
||||
<div class="relative group">
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Attached image {idx + 1}"
|
||||
class="w-20 h-20 object-cover rounded border border-gray-200 dark:border-white/10"
|
||||
/>
|
||||
<button
|
||||
class="absolute -top-2 -right-2 bg-red-500 text-white rounded-full w-6 h-6 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onclick={() => removeImage(idx)}
|
||||
title="Remove image"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Error message -->
|
||||
{#if imageError}
|
||||
<div class="mb-2 p-2 bg-red-100 dark:bg-red-900/20 text-red-700 dark:text-red-400 rounded text-sm">
|
||||
{imageError}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="flex gap-2">
|
||||
<!-- Hidden file input -->
|
||||
<input
|
||||
type="file"
|
||||
accept=".jpg,.jpeg,.png,.gif,.webp"
|
||||
multiple
|
||||
class="hidden"
|
||||
bind:this={fileInput}
|
||||
onchange={handleImageSelect}
|
||||
/>
|
||||
|
||||
<ExpandableTextarea
|
||||
bind:value={userInput}
|
||||
placeholder="Type a message..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-col gap-2">
|
||||
{#if isStreaming}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelStreaming}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
title="Attach image"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M1 5.25A2.25 2.25 0 0 1 3.25 3h13.5A2.25 2.25 0 0 1 19 5.25v9.5A2.25 2.25 0 0 1 16.75 17H3.25A2.25 2.25 0 0 1 1 14.75v-9.5Zm1.5 5.81v3.69c0 .414.336.75.75.75h13.5a.75.75 0 0 0 .75-.75v-2.69l-2.22-2.219a.75.75 0 0 0-1.06 0l-1.91 1.909.47.47a.75.75 0 1 1-1.06 1.06L6.53 8.091a.75.75 0 0 0-1.06 0l-2.97 2.97ZM12 7a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={sendMessage}
|
||||
disabled={(!userInput.trim() && attachedImages.length === 0) || !$selectedModelStore}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,398 @@
|
||||
<script lang="ts">
|
||||
import { renderMarkdown, escapeHtml, renderStreamingMarkdown, createStreamingCache } from "../../lib/markdown";
|
||||
import type { RenderedBlock } from "../../lib/markdown";
|
||||
import { Copy, Check, Pencil, X, Save, RefreshCw, ChevronDown, ChevronRight, Brain, Code } from "lucide-svelte";
|
||||
import { getTextContent, getImageUrls } from "../../lib/types";
|
||||
import type { ContentPart } from "../../lib/types";
|
||||
|
||||
interface Props {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string | ContentPart[];
|
||||
reasoning_content?: string;
|
||||
reasoningTimeMs?: number;
|
||||
isStreaming?: boolean;
|
||||
isReasoning?: boolean;
|
||||
onEdit?: (newContent: string) => void;
|
||||
onRegenerate?: () => void;
|
||||
}
|
||||
|
||||
let { role, content, reasoning_content = "", reasoningTimeMs = 0, isStreaming = false, isReasoning = false, onEdit, onRegenerate }: Props = $props();
|
||||
|
||||
let textContent = $derived(getTextContent(content));
|
||||
let imageUrls = $derived(getImageUrls(content));
|
||||
let hasImages = $derived(imageUrls.length > 0);
|
||||
let canEdit = $derived(onEdit !== undefined && !hasImages);
|
||||
|
||||
let streamingCache = createStreamingCache();
|
||||
let renderedParts = $derived.by(() => {
|
||||
if (role !== "assistant") {
|
||||
return { blocks: [{ id: -1, html: escapeHtml(textContent).replace(/\n/g, '<br>') }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
if (!isStreaming) {
|
||||
streamingCache = createStreamingCache();
|
||||
return { blocks: [{ id: -1, html: renderMarkdown(textContent) }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
return renderStreamingMarkdown(textContent, streamingCache);
|
||||
});
|
||||
let copied = $state(false);
|
||||
let showRaw = $state(false);
|
||||
let isEditing = $state(false);
|
||||
let editContent = $state("");
|
||||
let showReasoning = $state(false);
|
||||
let modalImageUrl = $state<string | null>(null);
|
||||
|
||||
function formatDuration(ms: number): string {
|
||||
if (ms < 1000) {
|
||||
return `${ms.toFixed(0)}ms`;
|
||||
}
|
||||
return `${(ms / 1000).toFixed(1)}s`;
|
||||
}
|
||||
|
||||
async function copyToClipboard() {
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(textContent);
|
||||
} else {
|
||||
// Fallback for non-secure contexts (HTTP)
|
||||
const textarea = document.createElement("textarea");
|
||||
textarea.value = textContent;
|
||||
textarea.style.position = "fixed";
|
||||
textarea.style.left = "-9999px";
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand("copy");
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
copied = true;
|
||||
setTimeout(() => (copied = false), 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
}
|
||||
|
||||
function startEdit() {
|
||||
editContent = textContent;
|
||||
isEditing = true;
|
||||
}
|
||||
|
||||
function cancelEdit() {
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function saveEdit() {
|
||||
if (onEdit && editContent.trim() !== textContent) {
|
||||
onEdit(editContent.trim());
|
||||
}
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function openModal(imageUrl: string) {
|
||||
modalImageUrl = imageUrl;
|
||||
document.body.style.overflow = "hidden";
|
||||
}
|
||||
|
||||
function closeModal(event?: MouseEvent) {
|
||||
// Only close if clicking the background, not the image
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
modalImageUrl = null;
|
||||
document.body.style.overflow = "";
|
||||
}
|
||||
|
||||
function handleModalKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeModal();
|
||||
}
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
saveEdit();
|
||||
} else if (event.key === "Escape") {
|
||||
cancelEdit();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex {role === 'user' ? 'justify-end' : 'justify-start'} mb-4">
|
||||
<div
|
||||
class="relative group rounded-lg px-4 py-2 {role === 'user'
|
||||
? 'max-w-[85%] bg-primary text-btn-primary-text'
|
||||
: 'w-full sm:w-4/5 bg-surface border border-gray-200 dark:border-white/10'}"
|
||||
>
|
||||
{#if role === "assistant"}
|
||||
{#if reasoning_content || isReasoning}
|
||||
<div class="mb-3 border border-gray-200 dark:border-white/10 rounded overflow-hidden">
|
||||
<button
|
||||
class="w-full flex items-center gap-2 px-3 py-2 bg-gray-50 dark:bg-white/5 hover:bg-gray-100 dark:hover:bg-white/10 transition-colors text-sm"
|
||||
onclick={() => showReasoning = !showReasoning}
|
||||
>
|
||||
{#if showReasoning}
|
||||
<ChevronDown class="w-4 h-4" />
|
||||
{:else}
|
||||
<ChevronRight class="w-4 h-4" />
|
||||
{/if}
|
||||
<Brain class="w-4 h-4" />
|
||||
<span class="font-medium">Reasoning</span>
|
||||
<span class="text-txtsecondary ml-2">
|
||||
({reasoning_content.length} chars{#if !isReasoning && reasoningTimeMs > 0}, {formatDuration(reasoningTimeMs)}{/if})
|
||||
</span>
|
||||
{#if isReasoning}
|
||||
<span class="ml-auto flex items-center gap-1 text-txtsecondary">
|
||||
<span class="w-1.5 h-1.5 bg-primary rounded-full animate-pulse"></span>
|
||||
reasoning...
|
||||
</span>
|
||||
{/if}
|
||||
</button>
|
||||
{#if showReasoning}
|
||||
<div class="px-3 py-2 bg-gray-50/50 dark:bg-white/[0.02] text-sm text-txtsecondary whitespace-pre-wrap font-mono">
|
||||
{reasoning_content}{#if isReasoning}<span class="inline-block w-1.5 h-4 bg-current animate-pulse ml-0.5"></span>{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if hasImages}
|
||||
<div class="mb-3 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-gray-200 dark:border-white/10 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-h-64 rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if showRaw}
|
||||
<div class="whitespace-pre-wrap font-mono text-sm">{textContent}</div>
|
||||
{:else}
|
||||
<div class="prose prose-sm dark:prose-invert max-w-none">
|
||||
{#each renderedParts.blocks as block (block.id)}
|
||||
{@html block.html}
|
||||
{/each}
|
||||
{@html renderedParts.pendingHtml}
|
||||
{#if isStreaming && !isReasoning}
|
||||
<span class="inline-block w-2 h-4 bg-current animate-pulse ml-0.5"></span>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if !isStreaming}
|
||||
<div class="flex gap-1 mt-2 pt-1 border-t border-gray-200 dark:border-white/10">
|
||||
{#if onRegenerate}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={onRegenerate}
|
||||
title="Regenerate response"
|
||||
>
|
||||
<RefreshCw class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? "Copied!" : "Copy to clipboard"}
|
||||
>
|
||||
{#if copied}
|
||||
<Check class="w-4 h-4 text-green-500" />
|
||||
{:else}
|
||||
<Copy class="w-4 h-4" />
|
||||
{/if}
|
||||
</button>
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 {showRaw ? 'text-primary' : 'text-txtsecondary'}"
|
||||
onclick={() => showRaw = !showRaw}
|
||||
title={showRaw ? "Show rendered" : "Show raw"}
|
||||
>
|
||||
<Code class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
{#if isEditing}
|
||||
<div class="flex flex-col gap-2 min-w-[300px]">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface text-txtmain focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
rows="3"
|
||||
bind:value={editContent}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
<div class="flex justify-end gap-2">
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={cancelEdit}
|
||||
title="Cancel"
|
||||
>
|
||||
<X class="w-4 h-4" />
|
||||
</button>
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={saveEdit}
|
||||
title="Save"
|
||||
>
|
||||
<Save class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
{#if hasImages}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-white/20 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-w-[200px] rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="whitespace-pre-wrap pr-8">{textContent}</div>
|
||||
{#if canEdit}
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-0 group-hover:opacity-100 transition-opacity bg-white/20 hover:bg-white/30 shadow-sm"
|
||||
onclick={startEdit}
|
||||
title="Edit message"
|
||||
>
|
||||
<Pencil class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Full-size image modal -->
|
||||
{#if modalImageUrl}
|
||||
<div
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/80 p-4"
|
||||
onclick={(e) => closeModal(e)}
|
||||
onkeydown={handleModalKeyDown}
|
||||
role="button"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 p-2 rounded-lg bg-white/10 hover:bg-white/20 text-white transition-colors"
|
||||
onclick={() => closeModal()}
|
||||
title="Close"
|
||||
>
|
||||
<X class="w-6 h-6" />
|
||||
</button>
|
||||
<img
|
||||
src={modalImageUrl}
|
||||
alt=""
|
||||
class="max-w-full max-h-full rounded pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.prose :global(pre) {
|
||||
background-color: var(--color-surface);
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
border-radius: 0.375rem;
|
||||
padding: 0.75rem;
|
||||
overflow-x: auto;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(code) {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
.prose :global(pre code) {
|
||||
background: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.prose :global(code:not(pre code)) {
|
||||
background-color: var(--color-surface);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
}
|
||||
|
||||
.prose :global(p) {
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(p:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.prose :global(ul),
|
||||
.prose :global(ol) {
|
||||
margin: 0.5rem 0;
|
||||
padding-left: 1.5rem;
|
||||
}
|
||||
|
||||
.prose :global(li) {
|
||||
margin: 0.25rem 0;
|
||||
}
|
||||
|
||||
.prose :global(h1),
|
||||
.prose :global(h2),
|
||||
.prose :global(h3),
|
||||
.prose :global(h4) {
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.prose :global(h1:first-child),
|
||||
.prose :global(h2:first-child),
|
||||
.prose :global(h3:first-child),
|
||||
.prose :global(h4:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(blockquote) {
|
||||
border-left: 3px solid var(--color-primary);
|
||||
padding-left: 1rem;
|
||||
margin: 0.5rem 0;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.prose :global(a) {
|
||||
color: var(--color-primary);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.prose :global(table) {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(th),
|
||||
.prose :global(td) {
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.prose :global(th) {
|
||||
background-color: var(--color-surface);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* Highlight.js theme overrides for dark mode */
|
||||
:global(.dark) .prose :global(.hljs) {
|
||||
background: transparent;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,121 @@
|
||||
<script lang="ts">
|
||||
import { untrack } from "svelte";
|
||||
import { Maximize2, X } from "lucide-svelte";
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
rows?: number;
|
||||
disabled?: boolean;
|
||||
onkeydown?: (event: KeyboardEvent) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
value = $bindable(),
|
||||
placeholder = "",
|
||||
rows = 3,
|
||||
disabled = false,
|
||||
onkeydown,
|
||||
}: Props = $props();
|
||||
|
||||
let isExpanded = $state(false);
|
||||
let expandedValue = $state("");
|
||||
let expandedTextarea: HTMLTextAreaElement | undefined = $state();
|
||||
|
||||
function openExpanded() {
|
||||
expandedValue = value;
|
||||
isExpanded = true;
|
||||
}
|
||||
|
||||
function closeExpanded() {
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function saveExpanded() {
|
||||
value = expandedValue;
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeExpanded();
|
||||
}
|
||||
}
|
||||
|
||||
// Focus the textarea when expanded view opens
|
||||
$effect(() => {
|
||||
if (isExpanded && expandedTextarea) {
|
||||
expandedTextarea.focus();
|
||||
const len = untrack(() => expandedValue.length);
|
||||
expandedTextarea.setSelectionRange(len, len);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex-1 relative group flex items-stretch min-h-0">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 pr-10 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-inset focus:ring-primary resize-none"
|
||||
{placeholder}
|
||||
{rows}
|
||||
bind:value
|
||||
{onkeydown}
|
||||
{disabled}
|
||||
></textarea>
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-60 md:opacity-0 group-hover:opacity-100 transition-opacity bg-surface/90 hover:bg-surface border border-gray-200 dark:border-white/10 shadow-sm"
|
||||
onclick={openExpanded}
|
||||
title="Expand to edit"
|
||||
type="button"
|
||||
{disabled}
|
||||
>
|
||||
<Maximize2 class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="fixed inset-0 z-50 flex items-center justify-center bg-black/50 p-4">
|
||||
<div class="w-full max-w-4xl h-[80vh] flex flex-col bg-surface rounded-lg shadow-xl border border-gray-200 dark:border-white/10">
|
||||
<!-- Header -->
|
||||
<div class="flex justify-between items-center p-4 border-b border-gray-200 dark:border-white/10">
|
||||
<h3 class="font-medium">Edit Text</h3>
|
||||
<button
|
||||
class="p-1.5 rounded-lg hover:bg-gray-100 dark:hover:bg-white/10"
|
||||
onclick={closeExpanded}
|
||||
title="Close"
|
||||
type="button"
|
||||
>
|
||||
<X class="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Textarea -->
|
||||
<div class="flex-1 p-4">
|
||||
<textarea
|
||||
bind:this={expandedTextarea}
|
||||
class="w-full h-full px-4 py-3 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder={placeholder}
|
||||
bind:value={expandedValue}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<!-- Footer -->
|
||||
<div class="flex justify-end gap-2 p-4 border-t border-gray-200 dark:border-white/10">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={closeExpanded}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={saveExpanded}
|
||||
type="button"
|
||||
>
|
||||
Done
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,234 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { generateImage } from "../../lib/imageApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-image-model", "");
|
||||
const selectedSizeStore = persistentStore<string>("playground-image-size", "1024x1024");
|
||||
|
||||
let prompt = $state("");
|
||||
let isGenerating = $state(false);
|
||||
let generatedImage = $state<string | null>(null);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let showFullscreen = $state(false);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.imageGenerating.set(isGenerating);
|
||||
});
|
||||
|
||||
async function generate() {
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (!trimmedPrompt || !$selectedModelStore || isGenerating) return;
|
||||
|
||||
isGenerating = true;
|
||||
error = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const response = await generateImage(
|
||||
$selectedModelStore,
|
||||
trimmedPrompt,
|
||||
$selectedSizeStore,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
if (response.data && response.data.length > 0) {
|
||||
const imageData = response.data[0];
|
||||
if (imageData.b64_json) {
|
||||
generatedImage = `data:image/png;base64,${imageData.b64_json}`;
|
||||
} else if (imageData.url) {
|
||||
generatedImage = imageData.url;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelGeneration() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearImage() {
|
||||
generatedImage = null;
|
||||
error = null;
|
||||
prompt = "";
|
||||
}
|
||||
|
||||
function downloadImage() {
|
||||
if (!generatedImage) return;
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = generatedImage;
|
||||
link.download = `generated-image-${Date.now()}.png`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
|
||||
function openFullscreen() {
|
||||
showFullscreen = true;
|
||||
}
|
||||
|
||||
function closeFullscreen(event?: MouseEvent) {
|
||||
// Only close if clicking the background, not the image
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
showFullscreen = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
generate();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$selectedSizeStore}
|
||||
disabled={isGenerating}
|
||||
>
|
||||
<optgroup label="Square">
|
||||
<option value="512x512">512x512</option>
|
||||
<option value="1024x1024">1024x1024</option>
|
||||
</optgroup>
|
||||
<optgroup label="Landscape">
|
||||
<option value="1024x768">1024x768 (4:3)</option>
|
||||
<option value="1280x720">1280x720 (16:9)</option>
|
||||
<option value="1792x1024">1792x1024 (SDXL)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Portrait">
|
||||
<option value="768x1024">768x1024 (3:4)</option>
|
||||
<option value="720x1280">720x1280 (9:16)</option>
|
||||
<option value="1024x1792">1024x1792 (SDXL)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to generate images.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Image display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isGenerating}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Generating image...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if generatedImage}
|
||||
<div class="relative max-w-full max-h-full flex items-center justify-center">
|
||||
<button
|
||||
class="p-0 border-0 bg-transparent cursor-pointer"
|
||||
onclick={openFullscreen}
|
||||
aria-label="View fullscreen"
|
||||
>
|
||||
<img
|
||||
src={generatedImage}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
class="absolute bottom-2 right-2 p-2 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
|
||||
onclick={(e) => { e.stopPropagation(); downloadImage(); }}
|
||||
aria-label="Download image"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<p>Enter a prompt below to generate an image</p>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Prompt input area -->
|
||||
<div class="shrink-0 flex flex-col md:flex-row gap-2">
|
||||
<ExpandableTextarea
|
||||
bind:value={prompt}
|
||||
placeholder="Describe the image you want to generate..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isGenerating || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-row md:flex-col gap-2">
|
||||
{#if isGenerating}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
|
||||
onclick={generate}
|
||||
disabled={!prompt.trim() || !$selectedModelStore}
|
||||
>
|
||||
Generate
|
||||
</button>
|
||||
<button
|
||||
class="btn flex-1 md:flex-none"
|
||||
onclick={clearImage}
|
||||
disabled={!generatedImage && !error && !prompt.trim()}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Fullscreen dialog -->
|
||||
{#if showFullscreen && generatedImage}
|
||||
<div
|
||||
class="fixed inset-0 bg-black/90 z-50 flex items-center justify-center p-4"
|
||||
onclick={(e) => closeFullscreen(e)}
|
||||
onkeydown={(e) => e.key === 'Escape' && closeFullscreen()}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 text-white hover:text-gray-300 text-2xl w-10 h-10 flex items-center justify-center rounded-full hover:bg-white/10 transition-colors"
|
||||
onclick={() => closeFullscreen()}
|
||||
aria-label="Close fullscreen"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
<img
|
||||
src={generatedImage}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,39 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { groupModels } from "../../lib/modelUtils";
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props();
|
||||
|
||||
let grouped = $derived(groupModels($models));
|
||||
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
||||
</script>
|
||||
|
||||
{#if hasModels}
|
||||
<select
|
||||
class="min-w-0 flex-1 basis-48 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value
|
||||
{disabled}
|
||||
>
|
||||
<option value="">{placeholder}</option>
|
||||
{#if grouped.local.length > 0}
|
||||
<optgroup label="Local">
|
||||
{#each grouped.local as model (model.id)}
|
||||
<option value={model.id}>{model.id}</option>
|
||||
{/each}
|
||||
</optgroup>
|
||||
{/if}
|
||||
{#each Object.entries(grouped.peersByProvider).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||
<optgroup label="Peer: {peerId}">
|
||||
{#each peerModels as model (model.id)}
|
||||
<option value={model.id}>{model.id}</option>
|
||||
{/each}
|
||||
</optgroup>
|
||||
{/each}
|
||||
</select>
|
||||
{/if}
|
||||
@@ -0,0 +1,14 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
featureName: string;
|
||||
}
|
||||
|
||||
let { featureName }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex items-center justify-center h-full">
|
||||
<div class="text-center text-txtsecondary">
|
||||
<p class="text-lg">{featureName}</p>
|
||||
<p class="text-sm mt-2">To be implemented</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,360 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { generateSpeech } from "../../lib/speechApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-speech-model", "");
|
||||
const selectedVoiceStore = persistentStore<string>("playground-speech-voice", "coral");
|
||||
const autoPlayStore = persistentStore<boolean>("playground-speech-autoplay", false);
|
||||
|
||||
let inputText = $state("");
|
||||
let isGenerating = $state(false);
|
||||
let generatedAudioUrl = $state<string | null>(null);
|
||||
let generatedVoice = $state<string | null>(null);
|
||||
let generatedTimestamp = $state<Date | null>(null);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let audioElement = $state<HTMLAudioElement | null>(null);
|
||||
let availableVoices = $state<string[]>(["coral", "alloy", "echo", "fable", "onyx", "nova", "shimmer"]);
|
||||
let isLoadingVoices = $state(false);
|
||||
|
||||
const defaultVoices = ["coral", "alloy", "echo", "fable", "onyx", "nova", "shimmer"];
|
||||
const CACHE_KEY = "playground-speech-voices-cache";
|
||||
|
||||
function getVoicesCache(): Record<string, string[]> {
|
||||
if (typeof window === "undefined") return {};
|
||||
try {
|
||||
const saved = localStorage.getItem(CACHE_KEY);
|
||||
return saved ? JSON.parse(saved) : {};
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveVoicesCache(cache: Record<string, string[]>) {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
localStorage.setItem(CACHE_KEY, JSON.stringify(cache));
|
||||
} catch (e) {
|
||||
console.error("Error saving voices cache", e);
|
||||
}
|
||||
}
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
let isInitialLoad = $state(true);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.speechGenerating.set(isGenerating);
|
||||
});
|
||||
|
||||
// On page load, restore cached voices for the selected model if available
|
||||
$effect(() => {
|
||||
const model = $selectedModelStore;
|
||||
|
||||
if (isInitialLoad) {
|
||||
isInitialLoad = false;
|
||||
// If we have cached voices for this model, use them
|
||||
const cache = getVoicesCache();
|
||||
if (model && cache[model]) {
|
||||
availableVoices = cache[model];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async function refreshVoices() {
|
||||
const model = $selectedModelStore;
|
||||
if (!model || isLoadingVoices) return;
|
||||
|
||||
isLoadingVoices = true;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/v1/audio/voices?model=${encodeURIComponent(model)}`);
|
||||
if (!response.ok) {
|
||||
// Fall back to default voices if API call fails
|
||||
availableVoices = defaultVoices;
|
||||
const cache = getVoicesCache();
|
||||
cache[model] = defaultVoices;
|
||||
saveVoicesCache(cache);
|
||||
selectedVoiceStore.set(defaultVoices[0]);
|
||||
return;
|
||||
}
|
||||
const data = await response.json();
|
||||
// Expect response to be an array of voice strings or an object with a voices array
|
||||
const voices = Array.isArray(data) ? data : (data.voices || defaultVoices);
|
||||
const newVoices = voices.length > 0 ? voices : defaultVoices;
|
||||
|
||||
availableVoices = newVoices;
|
||||
const cache = getVoicesCache();
|
||||
cache[model] = newVoices;
|
||||
saveVoicesCache(cache);
|
||||
|
||||
// Reset to first available voice
|
||||
selectedVoiceStore.set(newVoices[0]);
|
||||
} catch {
|
||||
// Fall back to default voices on error
|
||||
availableVoices = defaultVoices;
|
||||
const cache = getVoicesCache();
|
||||
cache[model] = defaultVoices;
|
||||
saveVoicesCache(cache);
|
||||
selectedVoiceStore.set(defaultVoices[0]);
|
||||
} finally {
|
||||
isLoadingVoices = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleVoiceChange(event: Event) {
|
||||
const value = (event.target as HTMLSelectElement).value;
|
||||
if (value === "(refresh)") {
|
||||
refreshVoices();
|
||||
} else {
|
||||
selectedVoiceStore.set(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-play effect when new audio is generated
|
||||
$effect(() => {
|
||||
if (generatedAudioUrl && $autoPlayStore && audioElement) {
|
||||
audioElement.load();
|
||||
audioElement.play().catch(() => {
|
||||
// Ignore auto-play errors (e.g., browser policy blocks)
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
async function generate() {
|
||||
const trimmedText = inputText.trim();
|
||||
if (!trimmedText || !$selectedModelStore || isGenerating) return;
|
||||
|
||||
isGenerating = true;
|
||||
error = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const audioBlob = await generateSpeech(
|
||||
$selectedModelStore,
|
||||
trimmedText,
|
||||
$selectedVoiceStore,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
// Revoke previous URL to prevent memory leaks
|
||||
if (generatedAudioUrl) {
|
||||
URL.revokeObjectURL(generatedAudioUrl);
|
||||
}
|
||||
|
||||
// Create object URL for the audio blob and store metadata
|
||||
generatedAudioUrl = URL.createObjectURL(audioBlob);
|
||||
generatedVoice = $selectedVoiceStore;
|
||||
generatedTimestamp = new Date();
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelGeneration() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearInput() {
|
||||
inputText = "";
|
||||
}
|
||||
|
||||
function downloadAudio() {
|
||||
if (!generatedAudioUrl) return;
|
||||
|
||||
const timestamp = (generatedTimestamp || new Date()).toISOString().replace(/[:.]/g, '-').slice(0, -5);
|
||||
const voice = generatedVoice || 'speech';
|
||||
const filename = `${voice}-${timestamp}.mp3`;
|
||||
|
||||
const a = document.createElement('a');
|
||||
a.href = generatedAudioUrl;
|
||||
a.download = filename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
}
|
||||
|
||||
function formatTimestamp(date: Date): string {
|
||||
return date.toLocaleString(undefined, {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: 'numeric',
|
||||
minute: '2-digit',
|
||||
hour12: true
|
||||
});
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
generate();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model and voice selectors -->
|
||||
<div class="shrink-0 flex gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} />
|
||||
<div class="flex gap-2">
|
||||
<select
|
||||
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
value={$selectedVoiceStore}
|
||||
onchange={handleVoiceChange}
|
||||
disabled={isGenerating || isLoadingVoices || !$selectedModelStore}
|
||||
>
|
||||
{#each availableVoices as voice (voice)}
|
||||
<option value={voice}>{voice}</option>
|
||||
{/each}
|
||||
<option value="(refresh)">(refresh)</option>
|
||||
</select>
|
||||
{#if $selectedModelStore && !getVoicesCache()[$selectedModelStore]}
|
||||
<button
|
||||
class="btn shrink-0"
|
||||
onclick={refreshVoices}
|
||||
disabled={isLoadingVoices}
|
||||
title={isLoadingVoices ? "Loading voices..." : "Load voices for this model"}
|
||||
>
|
||||
{#if isLoadingVoices}
|
||||
<svg class="w-5 h-5 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"></path>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to generate speech.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Audio display area -->
|
||||
<div class="shrink-0 mb-4 bg-surface border border-gray-200 dark:border-white/10 rounded p-4 md:p-6">
|
||||
{#if isGenerating}
|
||||
<div class="flex items-center justify-center text-txtsecondary py-8">
|
||||
<div class="text-center">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Generating speech...</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="flex items-center justify-center py-8">
|
||||
<div class="text-center text-red-500">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if generatedAudioUrl}
|
||||
<div class="flex flex-col gap-4">
|
||||
<!-- Header with metadata and download -->
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<div class="flex flex-wrap gap-3 text-sm text-txtsecondary">
|
||||
{#if generatedVoice}
|
||||
<span class="flex items-center gap-1">
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
|
||||
</svg>
|
||||
{generatedVoice}
|
||||
</span>
|
||||
{/if}
|
||||
{#if generatedTimestamp}
|
||||
<span class="flex items-center gap-1">
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z"></path>
|
||||
</svg>
|
||||
{formatTimestamp(generatedTimestamp)}
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="btn shrink-0"
|
||||
onclick={downloadAudio}
|
||||
title="Download audio file"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Audio player with larger controls -->
|
||||
<div class="w-full">
|
||||
<audio bind:this={audioElement} controls class="w-full h-12 md:h-16">
|
||||
<source src={generatedAudioUrl} type="audio/mpeg" />
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center text-txtsecondary py-8">
|
||||
<div class="text-center">
|
||||
<svg class="w-12 h-12 md:w-16 md:h-16 mx-auto mb-2 opacity-40" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z"></path>
|
||||
</svg>
|
||||
<p>Enter text below to convert to speech</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Text input area -->
|
||||
<div class="flex-1 flex flex-col md:flex-row gap-2 min-h-0">
|
||||
<ExpandableTextarea
|
||||
bind:value={inputText}
|
||||
placeholder="Enter text to convert to speech..."
|
||||
rows={8}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isGenerating || !$selectedModelStore}
|
||||
/>
|
||||
<div class="shrink-0 flex md:flex-col gap-2">
|
||||
{#if isGenerating}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
|
||||
onclick={generate}
|
||||
disabled={!inputText.trim() || !$selectedModelStore}
|
||||
>
|
||||
Generate
|
||||
</button>
|
||||
<button
|
||||
class="btn flex-1 md:flex-none"
|
||||
onclick={clearInput}
|
||||
disabled={!inputText.trim()}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
<label class="flex items-center justify-center gap-2 text-sm cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
bind:checked={$autoPlayStore}
|
||||
class="cursor-pointer"
|
||||
/>
|
||||
Auto-play
|
||||
</label>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -1,4 +1,5 @@
|
||||
@import "tailwindcss";
|
||||
@import "katex/dist/katex.min.css";
|
||||
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
|
||||
|
||||
@theme {
|
||||
@@ -0,0 +1,24 @@
|
||||
import type { AudioTranscriptionResponse } from "./types";
|
||||
|
||||
export async function transcribeAudio(
|
||||
model: string,
|
||||
file: File,
|
||||
signal?: AbortSignal
|
||||
): Promise<AudioTranscriptionResponse> {
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
formData.append("model", model);
|
||||
|
||||
const response = await fetch("/v1/audio/transcriptions", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Audio API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
import type { ChatMessage, ChatCompletionRequest } from "./types";
|
||||
|
||||
export interface StreamChunk {
|
||||
content: string;
|
||||
reasoning_content?: string;
|
||||
done: boolean;
|
||||
}
|
||||
|
||||
export interface ChatOptions {
|
||||
temperature?: number;
|
||||
}
|
||||
|
||||
function parseSSELine(line: string): StreamChunk | null {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith("data: ")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") {
|
||||
return { content: "", done: true };
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
const delta = parsed.choices?.[0]?.delta;
|
||||
const content = delta?.content || "";
|
||||
const reasoning_content = delta?.reasoning_content || "";
|
||||
|
||||
if (content || reasoning_content) {
|
||||
return { content, reasoning_content, done: false };
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function* streamChatCompletion(
|
||||
model: string,
|
||||
messages: ChatMessage[],
|
||||
signal?: AbortSignal,
|
||||
options?: ChatOptions
|
||||
): AsyncGenerator<StreamChunk> {
|
||||
const request: ChatCompletionRequest = {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
temperature: options?.temperature,
|
||||
};
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Chat API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("Response body is not readable");
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const result = parseSSELine(line);
|
||||
if (result?.done) {
|
||||
yield result;
|
||||
return;
|
||||
}
|
||||
if (result) {
|
||||
yield result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process any remaining buffer
|
||||
const result = parseSSELine(buffer);
|
||||
if (result && !result.done) {
|
||||
yield result;
|
||||
}
|
||||
|
||||
yield { content: "", done: true };
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import type { ImageGenerationRequest, ImageGenerationResponse } from "./types";
|
||||
|
||||
export async function generateImage(
|
||||
model: string,
|
||||
prompt: string,
|
||||
size: string,
|
||||
signal?: AbortSignal
|
||||
): Promise<ImageGenerationResponse> {
|
||||
const request: ImageGenerationRequest = {
|
||||
model,
|
||||
prompt,
|
||||
n: 1,
|
||||
size,
|
||||
};
|
||||
|
||||
const response = await fetch("/v1/images/generations", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Image API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { renderMarkdown, escapeHtml, splitCompleteBlocks, closePendingBlock, normalizeLatexDelimiters, renderStreamingMarkdown, createStreamingCache } from "./markdown";
|
||||
|
||||
describe("renderMarkdown", () => {
|
||||
describe("basic markdown", () => {
|
||||
it("renders plain text", () => {
|
||||
const result = renderMarkdown("Hello world");
|
||||
expect(result).toContain("Hello world");
|
||||
});
|
||||
|
||||
it("renders bold text", () => {
|
||||
const result = renderMarkdown("**bold**");
|
||||
expect(result).toContain("<strong>bold</strong>");
|
||||
});
|
||||
|
||||
it("renders italic text", () => {
|
||||
const result = renderMarkdown("*italic*");
|
||||
expect(result).toContain("<em>italic</em>");
|
||||
});
|
||||
|
||||
it("renders code blocks", () => {
|
||||
const result = renderMarkdown("```js\nconst x = 1;\n```");
|
||||
expect(result).toContain("hljs");
|
||||
expect(result).toContain("const");
|
||||
});
|
||||
|
||||
it("returns empty string for empty content", () => {
|
||||
const result = renderMarkdown("");
|
||||
expect(result).toBe("");
|
||||
});
|
||||
|
||||
it("returns empty string for null/undefined content", () => {
|
||||
// @ts-expect-error - testing null input
|
||||
expect(renderMarkdown(null)).toBe("");
|
||||
// @ts-expect-error - testing undefined input
|
||||
expect(renderMarkdown(undefined)).toBe("");
|
||||
});
|
||||
});
|
||||
|
||||
describe("KaTeX math rendering", () => {
|
||||
it("renders inline math with $...$ syntax", () => {
|
||||
const result = renderMarkdown("The equation $E = mc^2$ is famous.");
|
||||
// KaTeX should convert this to HTML with katex class
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("E");
|
||||
expect(result).toContain("=");
|
||||
expect(result).toContain("mc");
|
||||
});
|
||||
|
||||
it("renders display math with $$...$$ syntax", () => {
|
||||
const result = renderMarkdown("$$\\int_{a}^{b} f(x) dx$$");
|
||||
// Math should be rendered with KaTeX
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("∫");
|
||||
expect(result).toContain("f(x)");
|
||||
});
|
||||
|
||||
it("renders complex LaTeX expressions", () => {
|
||||
const result = renderMarkdown("$$\\sum_{i=1}^{n} x_i = \\frac{1}{n}\\sum_{i=1}^{n} x_i$$");
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("∑"); // or the MathML equivalent
|
||||
});
|
||||
|
||||
it("renders LaTeX with Greek letters", () => {
|
||||
const result = renderMarkdown("$\\alpha + \\beta = \\gamma$");
|
||||
expect(result).toContain("katex");
|
||||
// Greek letters should be rendered
|
||||
expect(result).toMatch(/[αβγ]|alpha|beta|gamma/);
|
||||
});
|
||||
|
||||
it("renders LaTeX with fractions", () => {
|
||||
const result = renderMarkdown("$\\frac{a}{b}$");
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("frac");
|
||||
});
|
||||
|
||||
it("renders LaTeX with subscripts and superscripts", () => {
|
||||
const result = renderMarkdown("$x^2 + y_3$");
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("sup"); // superscript
|
||||
expect(result).toContain("sub"); // subscript
|
||||
});
|
||||
|
||||
it("renders multiple inline math expressions in one paragraph", () => {
|
||||
const result = renderMarkdown("First $x = 1$ and then $y = 2$.");
|
||||
// Should contain multiple katex spans
|
||||
const katexMatches = result.match(/katex/g);
|
||||
expect(katexMatches?.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("renders math within a larger markdown document", () => {
|
||||
const markdown = `# Heading
|
||||
|
||||
This is a paragraph with $E = mc^2$ inline math.
|
||||
|
||||
$$\\int_0^\\infty e^{-x} dx = 1$$
|
||||
|
||||
More text here.
|
||||
`;
|
||||
const result = renderMarkdown(markdown);
|
||||
expect(result).toContain("<h1>Heading</h1>");
|
||||
expect(result).toContain("katex");
|
||||
// Both inline and display math should be rendered
|
||||
expect(result).toContain("E = mc");
|
||||
expect(result).toContain("∫");
|
||||
expect(result).toContain("∞");
|
||||
});
|
||||
|
||||
it("handles escaped dollar signs", () => {
|
||||
const result = renderMarkdown("This costs \\$5 and $x = 1$.");
|
||||
// Should render the escaped $5 as text and the math
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("$5");
|
||||
});
|
||||
|
||||
it("handles empty math expressions gracefully", () => {
|
||||
// Empty math should not break the renderer
|
||||
const result = renderMarkdown("$$$");
|
||||
expect(result).toBeTruthy();
|
||||
});
|
||||
|
||||
it("renders LaTeX matrices", () => {
|
||||
const result = renderMarkdown("$$\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}$$");
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("pmatrix");
|
||||
});
|
||||
|
||||
it("renders LaTeX square roots", () => {
|
||||
const result = renderMarkdown("$\\sqrt{x^2 + y^2}$");
|
||||
expect(result).toContain("katex");
|
||||
expect(result).toContain("sqrt");
|
||||
});
|
||||
|
||||
it("renders \\[...\\] display math", () => {
|
||||
const result = renderMarkdown("\\[\nx^2 + y^2 = z^2\n\\]");
|
||||
expect(result).toContain("katex");
|
||||
});
|
||||
|
||||
it("renders \\(...\\) inline math", () => {
|
||||
const result = renderMarkdown("The equation \\(E = mc^2\\) is famous.");
|
||||
expect(result).toContain("katex");
|
||||
});
|
||||
});
|
||||
|
||||
describe("normalizeLatexDelimiters", () => {
|
||||
it("converts \\[...\\] to $$...$$", () => {
|
||||
expect(normalizeLatexDelimiters("\\[\nx^2\n\\]")).toBe("$$\nx^2\n$$");
|
||||
});
|
||||
|
||||
it("converts \\(...\\) to $...$", () => {
|
||||
expect(normalizeLatexDelimiters("\\(x^2\\)")).toBe("$x^2$");
|
||||
});
|
||||
|
||||
it("leaves $$ and $ delimiters unchanged", () => {
|
||||
const text = "$$x^2$$ and $y$";
|
||||
expect(normalizeLatexDelimiters(text)).toBe(text);
|
||||
});
|
||||
|
||||
it("handles multiple occurrences", () => {
|
||||
expect(normalizeLatexDelimiters("\\(a\\) and \\(b\\)")).toBe("$a$ and $b$");
|
||||
});
|
||||
});
|
||||
|
||||
describe("escapeHtml", () => {
|
||||
it("escapes HTML entities", () => {
|
||||
expect(escapeHtml("<script>")).toBe("<script>");
|
||||
expect(escapeHtml('"quoted"')).toBe(""quoted"");
|
||||
expect(escapeHtml("'single'")).toBe("'single'");
|
||||
expect(escapeHtml("a & b")).toBe("a & b");
|
||||
});
|
||||
|
||||
it("handles empty string", () => {
|
||||
expect(escapeHtml("")).toBe("");
|
||||
});
|
||||
});
|
||||
|
||||
describe("error handling", () => {
|
||||
it("does not throw on invalid LaTeX syntax", () => {
|
||||
// Invalid LaTeX should not crash the renderer
|
||||
expect(() => renderMarkdown("$\\invalidcommand{")).not.toThrow();
|
||||
});
|
||||
|
||||
it("returns fallback HTML on processing errors", () => {
|
||||
// Very large or malformed input should be handled
|
||||
const result = renderMarkdown("$" + "a".repeat(10000) + "$");
|
||||
expect(result).toBeTruthy();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("splitCompleteBlocks", () => {
|
||||
it("returns everything as pending when no blank line", () => {
|
||||
const result = splitCompleteBlocks("Hello world");
|
||||
expect(result.complete).toBe("");
|
||||
expect(result.pending).toBe("Hello world");
|
||||
});
|
||||
|
||||
it("returns empty for empty input", () => {
|
||||
const result = splitCompleteBlocks("");
|
||||
expect(result.complete).toBe("");
|
||||
expect(result.pending).toBe("");
|
||||
});
|
||||
|
||||
it("splits on blank line between paragraphs", () => {
|
||||
const result = splitCompleteBlocks("First paragraph.\n\nSecond paragraph");
|
||||
expect(result.complete).toBe("First paragraph.\n");
|
||||
expect(result.pending).toBe("Second paragraph");
|
||||
});
|
||||
|
||||
it("splits multiple paragraphs at last blank line", () => {
|
||||
const result = splitCompleteBlocks("Para 1.\n\nPara 2.\n\nPara 3");
|
||||
expect(result.complete).toBe("Para 1.\n\nPara 2.\n");
|
||||
expect(result.pending).toBe("Para 3");
|
||||
});
|
||||
|
||||
it("treats closed code fence as complete boundary", () => {
|
||||
const text = "```js\nconst x = 1;\n```\nMore text";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("```js\nconst x = 1;\n```");
|
||||
expect(result.pending).toBe("More text");
|
||||
});
|
||||
|
||||
it("treats unclosed code fence as pending", () => {
|
||||
const text = "Done paragraph.\n\n```js\nconst x = 1;";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("Done paragraph.\n");
|
||||
expect(result.pending).toBe("```js\nconst x = 1;");
|
||||
});
|
||||
|
||||
it("does not split on blank lines inside code fences", () => {
|
||||
const text = "```\nline1\n\nline2\n```";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("```\nline1\n\nline2\n```");
|
||||
expect(result.pending).toBe("");
|
||||
});
|
||||
|
||||
it("handles tilde fences", () => {
|
||||
const text = "~~~py\nprint('hi')\n~~~\nAfter";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("~~~py\nprint('hi')\n~~~");
|
||||
expect(result.pending).toBe("After");
|
||||
});
|
||||
|
||||
it("does not close backtick fence with tilde fence", () => {
|
||||
const text = "```\ncode\n~~~\nstill code";
|
||||
const result = splitCompleteBlocks(text);
|
||||
// The ~~~ should not close a backtick fence, so everything from ``` onward is pending
|
||||
expect(result.complete).toBe("");
|
||||
expect(result.pending).toBe("```\ncode\n~~~\nstill code");
|
||||
});
|
||||
|
||||
it("treats closed math block as complete boundary", () => {
|
||||
const text = "$$\nx^2\n$$\nAfter";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("$$\nx^2\n$$");
|
||||
expect(result.pending).toBe("After");
|
||||
});
|
||||
|
||||
it("treats unclosed math block as pending", () => {
|
||||
const text = "Before.\n\n$$\nx^2";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("Before.\n");
|
||||
expect(result.pending).toBe("$$\nx^2");
|
||||
});
|
||||
|
||||
it("treats closed \\[...\\] math block as complete boundary", () => {
|
||||
const text = "\\[\nx^2\n\\]\nAfter";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("\\[\nx^2\n\\]");
|
||||
expect(result.pending).toBe("After");
|
||||
});
|
||||
|
||||
it("treats unclosed \\[ math block as pending", () => {
|
||||
const text = "Before.\n\n\\[\nx^2";
|
||||
const result = splitCompleteBlocks(text);
|
||||
expect(result.complete).toBe("Before.\n");
|
||||
expect(result.pending).toBe("\\[\nx^2");
|
||||
});
|
||||
|
||||
it("handles trailing blank line making everything complete", () => {
|
||||
const text = "Hello world.\n";
|
||||
const result = splitCompleteBlocks(text);
|
||||
// Last line is empty string after split, which is a blank line
|
||||
expect(result.complete).toBe("Hello world.\n");
|
||||
expect(result.pending).toBe("");
|
||||
});
|
||||
});
|
||||
|
||||
describe("closePendingBlock", () => {
|
||||
it("returns empty string for empty input", () => {
|
||||
expect(closePendingBlock("")).toBe("");
|
||||
});
|
||||
|
||||
it("returns plain text unchanged", () => {
|
||||
expect(closePendingBlock("Hello world")).toBe("Hello world");
|
||||
});
|
||||
|
||||
it("closes an open backtick code fence", () => {
|
||||
const result = closePendingBlock("```python\nprint('hi')");
|
||||
expect(result).toBe("```python\nprint('hi')\n```");
|
||||
});
|
||||
|
||||
it("closes an open tilde code fence", () => {
|
||||
const result = closePendingBlock("~~~js\nconst x = 1;");
|
||||
expect(result).toBe("~~~js\nconst x = 1;\n~~~");
|
||||
});
|
||||
|
||||
it("does not modify already-closed code fence", () => {
|
||||
const text = "```py\ncode\n```";
|
||||
expect(closePendingBlock(text)).toBe(text);
|
||||
});
|
||||
|
||||
it("closes an open math block", () => {
|
||||
const result = closePendingBlock("$$\nx^2 + y^2");
|
||||
expect(result).toBe("$$\nx^2 + y^2\n$$");
|
||||
});
|
||||
|
||||
it("does not modify already-closed math block", () => {
|
||||
const text = "$$\nx^2\n$$";
|
||||
expect(closePendingBlock(text)).toBe(text);
|
||||
});
|
||||
|
||||
it("closes an open \\[ math block with \\]", () => {
|
||||
const result = closePendingBlock("\\[\nx^2 + y^2");
|
||||
expect(result).toBe("\\[\nx^2 + y^2\n\\]");
|
||||
});
|
||||
|
||||
it("does not modify already-closed \\[...\\] math block", () => {
|
||||
const text = "\\[\nx^2\n\\]";
|
||||
expect(closePendingBlock(text)).toBe(text);
|
||||
});
|
||||
|
||||
it("closes code fence when preceded by regular text", () => {
|
||||
const result = closePendingBlock("Some text\n```\ncode");
|
||||
expect(result).toBe("Some text\n```\ncode\n```");
|
||||
});
|
||||
|
||||
it("leaves headers unchanged", () => {
|
||||
expect(closePendingBlock("## Hello")).toBe("## Hello");
|
||||
});
|
||||
|
||||
it("leaves tables unchanged", () => {
|
||||
const table = "| a | b |\n| --- | --- |\n| 1 | 2 |";
|
||||
expect(closePendingBlock(table)).toBe(table);
|
||||
});
|
||||
|
||||
it("leaves lists unchanged", () => {
|
||||
expect(closePendingBlock("- item 1\n- item 2")).toBe("- item 1\n- item 2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("renderStreamingMarkdown", () => {
|
||||
it("renders complete blocks and pending as markdown", () => {
|
||||
const cache = createStreamingCache();
|
||||
const text = "# Hello\n\nWorld";
|
||||
const { blocks, pendingHtml } = renderStreamingMarkdown(text, cache);
|
||||
expect(blocks).toHaveLength(1);
|
||||
expect(blocks[0].html).toContain("<h1>Hello</h1>");
|
||||
expect(pendingHtml).toContain("World");
|
||||
expect(pendingHtml).toContain("<p>");
|
||||
});
|
||||
|
||||
it("preserves existing blocks when complete portion is unchanged", () => {
|
||||
const cache = createStreamingCache();
|
||||
renderStreamingMarkdown("# Hello\n\nWor", cache);
|
||||
const firstBlocks = cache.blocks;
|
||||
|
||||
const { blocks } = renderStreamingMarkdown("# Hello\n\nWorld", cache);
|
||||
// Same block array reference — nothing changed in the complete section
|
||||
expect(blocks).toBe(firstBlocks);
|
||||
expect(cache.completeKey).toBe("# Hello\n");
|
||||
});
|
||||
|
||||
it("appends a new block when a new section completes", () => {
|
||||
const cache = createStreamingCache();
|
||||
renderStreamingMarkdown("# Hello\n\nParagraph", cache);
|
||||
expect(cache.blocks).toHaveLength(1);
|
||||
const firstBlock = cache.blocks[0];
|
||||
|
||||
renderStreamingMarkdown("# Hello\n\nParagraph.\n\nMore", cache);
|
||||
expect(cache.blocks).toHaveLength(2);
|
||||
// First block is preserved with the same id and html
|
||||
expect(cache.blocks[0].id).toBe(firstBlock.id);
|
||||
expect(cache.blocks[0].html).toBe(firstBlock.html);
|
||||
// Second block contains the new paragraph
|
||||
expect(cache.blocks[1].html).toContain("Paragraph.");
|
||||
});
|
||||
|
||||
it("assigns unique stable ids to each block", () => {
|
||||
const cache = createStreamingCache();
|
||||
renderStreamingMarkdown("A.\n\nB.\n\nC", cache);
|
||||
expect(cache.blocks).toHaveLength(1);
|
||||
const id0 = cache.blocks[0].id;
|
||||
|
||||
renderStreamingMarkdown("A.\n\nB.\n\nC.\n\nD", cache);
|
||||
expect(cache.blocks).toHaveLength(2);
|
||||
expect(cache.blocks[0].id).toBe(id0);
|
||||
expect(cache.blocks[1].id).toBe(id0 + 1);
|
||||
});
|
||||
|
||||
it("renders pending code block with syntax highlighting", () => {
|
||||
const cache = createStreamingCache();
|
||||
const text = "Done.\n\n```python\nprint('hello')";
|
||||
const { pendingHtml } = renderStreamingMarkdown(text, cache);
|
||||
expect(pendingHtml).toContain("<code");
|
||||
expect(pendingHtml).toContain("hljs");
|
||||
});
|
||||
|
||||
it("renders pending table as markdown", () => {
|
||||
const cache = createStreamingCache();
|
||||
const text = "Done.\n\n| a | b |\n| --- | --- |\n| 1 | 2 |";
|
||||
const { pendingHtml } = renderStreamingMarkdown(text, cache);
|
||||
expect(pendingHtml).toContain("<table>");
|
||||
expect(pendingHtml).toContain("<td>");
|
||||
});
|
||||
|
||||
it("renders pending portion through markdown pipeline", () => {
|
||||
const cache = createStreamingCache();
|
||||
const text = "Done.\n\nSome **bold** text";
|
||||
const { pendingHtml } = renderStreamingMarkdown(text, cache);
|
||||
expect(pendingHtml).toContain("<strong>bold</strong>");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,260 @@
|
||||
import { unified } from "unified";
|
||||
import remarkParse from "remark-parse";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import remarkRehype from "remark-rehype";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import rehypeStringify from "rehype-stringify";
|
||||
import hljs from "highlight.js";
|
||||
import { visit } from "unist-util-visit";
|
||||
import type { Element, Root } from "hast";
|
||||
|
||||
// Custom plugin to highlight code blocks with highlight.js
|
||||
function rehypeHighlight() {
|
||||
return (tree: Root) => {
|
||||
visit(tree, "element", (node: Element) => {
|
||||
if (node.tagName === "code" && node.properties) {
|
||||
const className = node.properties.className;
|
||||
const classes = Array.isArray(className)
|
||||
? className.filter((c): c is string => typeof c === "string")
|
||||
: [];
|
||||
const lang = classes
|
||||
.find((c) => c.startsWith("language-"))
|
||||
?.replace("language-", "");
|
||||
|
||||
const text = node.children
|
||||
.filter((child): child is { type: "text"; value: string } => child.type === "text")
|
||||
.map((child) => child.value)
|
||||
.join("");
|
||||
|
||||
if (text) {
|
||||
const language = lang && hljs.getLanguage(lang) ? lang : "plaintext";
|
||||
const highlighted = hljs.highlight(text, { language }).value;
|
||||
|
||||
// Replace the text node with raw HTML
|
||||
node.properties.className = [
|
||||
"hljs",
|
||||
`language-${language}`,
|
||||
...classes.filter((c) => !c.startsWith("language-")),
|
||||
];
|
||||
// Use type assertion since we're modifying the tree structure
|
||||
(node.children as unknown) = [
|
||||
{ type: "raw", value: highlighted },
|
||||
];
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
export function escapeHtml(text: string): string {
|
||||
const htmlEntities: Record<string, string> = {
|
||||
"&": "&",
|
||||
"<": "<",
|
||||
">": ">",
|
||||
'"': """,
|
||||
"'": "'",
|
||||
};
|
||||
return text.replace(/[&<>"']/g, (char) => htmlEntities[char]);
|
||||
}
|
||||
|
||||
// Create the unified processor
|
||||
const processor = unified()
|
||||
.use(remarkParse)
|
||||
.use(remarkGfm)
|
||||
.use(remarkMath)
|
||||
.use(remarkRehype, { allowDangerousHtml: true })
|
||||
.use(rehypeKatex)
|
||||
.use(rehypeHighlight)
|
||||
.use(rehypeStringify, { allowDangerousHtml: true });
|
||||
|
||||
export function splitCompleteBlocks(text: string): { complete: string; pending: string } {
|
||||
if (!text) {
|
||||
return { complete: "", pending: "" };
|
||||
}
|
||||
|
||||
const lines = text.split("\n");
|
||||
let lastCompleteBoundary = -1; // index of last line that ends a complete block
|
||||
let inFence = false;
|
||||
let fenceChar = "";
|
||||
let inMathBlock = false;
|
||||
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const trimmed = lines[i].trimEnd();
|
||||
|
||||
if (inFence) {
|
||||
// Check for closing fence: same character, at least 3, no other content
|
||||
if (new RegExp(`^\\s*${fenceChar.replace(/~/g, "\\~")}{3,}\\s*$`).test(trimmed)) {
|
||||
inFence = false;
|
||||
fenceChar = "";
|
||||
lastCompleteBoundary = i;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inMathBlock) {
|
||||
if (trimmed === "$$" || trimmed === "\\]") {
|
||||
inMathBlock = false;
|
||||
lastCompleteBoundary = i;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for opening fence
|
||||
const fenceMatch = trimmed.match(/^(\s*)(```|~~~)/);
|
||||
if (fenceMatch) {
|
||||
// Check if it's an opening fence (may have language info after)
|
||||
// A line with just ``` or ~~~ could be opening or closing, but since we're not in a fence it's opening
|
||||
fenceChar = fenceMatch[2][0]; // '`' or '~'
|
||||
inFence = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for opening math block
|
||||
if (trimmed === "$$" || trimmed === "\\[") {
|
||||
inMathBlock = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Outside fences/math: blank line marks a complete boundary
|
||||
if (trimmed === "") {
|
||||
lastCompleteBoundary = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastCompleteBoundary < 0) {
|
||||
return { complete: "", pending: text };
|
||||
}
|
||||
|
||||
const completeLines = lines.slice(0, lastCompleteBoundary + 1);
|
||||
const pendingLines = lines.slice(lastCompleteBoundary + 1);
|
||||
|
||||
return {
|
||||
complete: completeLines.join("\n"),
|
||||
pending: pendingLines.join("\n"),
|
||||
};
|
||||
}
|
||||
|
||||
export function closePendingBlock(pending: string): string {
|
||||
if (!pending) return "";
|
||||
|
||||
const lines = pending.split("\n");
|
||||
let inFence = false;
|
||||
let fenceStr = "";
|
||||
let inMathBlock = false;
|
||||
let mathClose = "";
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trimEnd();
|
||||
|
||||
if (inFence) {
|
||||
if (new RegExp(`^\\s*${fenceStr[0] === "~" ? "~~~" : "\\`\\`\\`"}\\s*$`).test(trimmed)) {
|
||||
inFence = false;
|
||||
fenceStr = "";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inMathBlock) {
|
||||
if (trimmed === "$$" || trimmed === "\\]") {
|
||||
inMathBlock = false;
|
||||
mathClose = "";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const fenceMatch = trimmed.match(/^(\s*)(```|~~~)/);
|
||||
if (fenceMatch) {
|
||||
fenceStr = fenceMatch[2];
|
||||
inFence = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (trimmed === "$$") {
|
||||
inMathBlock = true;
|
||||
mathClose = "$$";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (trimmed === "\\[") {
|
||||
inMathBlock = true;
|
||||
mathClose = "\\]";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (inFence) return pending + "\n" + fenceStr;
|
||||
if (inMathBlock) return pending + "\n" + mathClose;
|
||||
return pending;
|
||||
}
|
||||
|
||||
export interface RenderedBlock {
|
||||
id: number;
|
||||
html: string;
|
||||
}
|
||||
|
||||
export interface StreamingCache {
|
||||
blocks: RenderedBlock[];
|
||||
nextId: number;
|
||||
completeKey: string;
|
||||
}
|
||||
|
||||
export function createStreamingCache(): StreamingCache {
|
||||
return { blocks: [], nextId: 0, completeKey: "" };
|
||||
}
|
||||
|
||||
export function renderStreamingMarkdown(
|
||||
text: string,
|
||||
cache: StreamingCache,
|
||||
): { blocks: RenderedBlock[]; pendingHtml: string } {
|
||||
const { complete, pending } = splitCompleteBlocks(text);
|
||||
|
||||
if (complete) {
|
||||
if (cache.completeKey !== complete) {
|
||||
if (complete.startsWith(cache.completeKey) && cache.completeKey.length > 0) {
|
||||
// Complete section grew — render only the new part as a new block
|
||||
const newPart = complete.slice(cache.completeKey.length);
|
||||
cache.blocks = [...cache.blocks, { id: cache.nextId++, html: renderMarkdown(newPart) }];
|
||||
} else {
|
||||
// Complete section changed unexpectedly — re-render as single block
|
||||
cache.blocks = [{ id: cache.nextId++, html: renderMarkdown(complete) }];
|
||||
}
|
||||
cache.completeKey = complete;
|
||||
}
|
||||
} else if (cache.blocks.length > 0) {
|
||||
cache.blocks = [];
|
||||
cache.completeKey = "";
|
||||
}
|
||||
|
||||
let pendingHtml = "";
|
||||
if (pending) {
|
||||
const closed = closePendingBlock(pending);
|
||||
pendingHtml = renderMarkdown(closed);
|
||||
}
|
||||
|
||||
return { blocks: cache.blocks, pendingHtml };
|
||||
}
|
||||
|
||||
// Convert \[...\] to $$...$$ and \(...\) to $...$
|
||||
export function normalizeLatexDelimiters(text: string): string {
|
||||
// Display math: \[...\] → $$...$$ (may span multiple lines)
|
||||
text = text.replace(/\\\[([\s\S]*?)\\\]/g, (_match, inner) => `$$${inner}$$`);
|
||||
// Inline math: \(...\) → $...$
|
||||
text = text.replace(/\\\(([\s\S]*?)\\\)/g, (_match, inner) => `$${inner}$`);
|
||||
return text;
|
||||
}
|
||||
|
||||
export function renderMarkdown(content: string): string {
|
||||
if (!content) {
|
||||
return "";
|
||||
}
|
||||
|
||||
try {
|
||||
const result = processor.processSync(normalizeLatexDelimiters(content));
|
||||
return String(result);
|
||||
} catch {
|
||||
// Fallback to escaped plain text if markdown parsing fails
|
||||
return `<p>${escapeHtml(content)}</p>`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import type { Model } from "./types";
|
||||
|
||||
export interface GroupedModels {
|
||||
local: Model[];
|
||||
peersByProvider: Record<string, Model[]>;
|
||||
}
|
||||
|
||||
export function groupModels(models: Model[]): GroupedModels {
|
||||
const available = models.filter((m) => !m.unlisted);
|
||||
const local = available.filter((m) => !m.peerID);
|
||||
const peerModels = available.filter((m) => m.peerID);
|
||||
|
||||
const peersByProvider = peerModels.reduce(
|
||||
(acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
if (!acc[peerId]) acc[peerId] = [];
|
||||
acc[peerId].push(model);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, Model[]>
|
||||
);
|
||||
|
||||
return { local, peersByProvider };
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
import type { SpeechGenerationRequest } from "./types";
|
||||
|
||||
export async function generateSpeech(
|
||||
model: string,
|
||||
input: string,
|
||||
voice: string,
|
||||
signal?: AbortSignal
|
||||
): Promise<Blob> {
|
||||
const request: SpeechGenerationRequest = {
|
||||
model,
|
||||
input,
|
||||
voice,
|
||||
};
|
||||
|
||||
const response = await fetch("/v1/audio/speech", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Speech API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
return response.blob();
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
export type ConnectionState = "connected" | "connecting" | "disconnected";
|
||||
|
||||
export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||
|
||||
export interface Model {
|
||||
id: string;
|
||||
state: ModelStatus;
|
||||
name: string;
|
||||
description: string;
|
||||
unlisted: boolean;
|
||||
peerID: string;
|
||||
}
|
||||
|
||||
export interface Metrics {
|
||||
id: number;
|
||||
timestamp: string;
|
||||
model: string;
|
||||
cache_tokens: number;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
prompt_per_second: number;
|
||||
tokens_per_second: number;
|
||||
duration_ms: number;
|
||||
has_capture: boolean;
|
||||
}
|
||||
|
||||
export interface ReqRespCapture {
|
||||
id: number;
|
||||
req_path: string;
|
||||
req_headers: Record<string, string>;
|
||||
req_body: string; // base64 encoded bytes
|
||||
resp_headers: Record<string, string>;
|
||||
resp_body: string; // base64 encoded bytes
|
||||
}
|
||||
|
||||
export interface LogData {
|
||||
source: "upstream" | "proxy";
|
||||
data: string;
|
||||
}
|
||||
|
||||
export interface InFlightStats {
|
||||
total: number;
|
||||
}
|
||||
|
||||
export interface APIEventEnvelope {
|
||||
type: "modelStatus" | "logData" | "metrics" | "inflight";
|
||||
data: string;
|
||||
}
|
||||
|
||||
export interface VersionInfo {
|
||||
build_date: string;
|
||||
commit: string;
|
||||
version: string;
|
||||
}
|
||||
|
||||
export type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||
|
||||
export type TextContentPart = {
|
||||
type: "text";
|
||||
text: string;
|
||||
};
|
||||
|
||||
export type ImageContentPart = {
|
||||
type: "image_url";
|
||||
image_url: { url: string };
|
||||
};
|
||||
|
||||
export type ContentPart = TextContentPart | ImageContentPart;
|
||||
|
||||
export interface ChatMessage {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string | ContentPart[];
|
||||
reasoning_content?: string;
|
||||
reasoningTimeMs?: number;
|
||||
}
|
||||
|
||||
export function getTextContent(content: string | ContentPart[]): string {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
const textParts = content.filter((part): part is TextContentPart => part.type === "text");
|
||||
return textParts.map((part) => part.text).join("\n");
|
||||
}
|
||||
|
||||
export function getImageUrls(content: string | ContentPart[]): string[] {
|
||||
if (typeof content === "string") {
|
||||
return [];
|
||||
}
|
||||
return content
|
||||
.filter((part): part is ImageContentPart => part.type === "image_url")
|
||||
.map((part) => part.image_url.url);
|
||||
}
|
||||
|
||||
export interface ChatCompletionRequest {
|
||||
model: string;
|
||||
messages: ChatMessage[];
|
||||
stream: boolean;
|
||||
temperature?: number;
|
||||
max_tokens?: number;
|
||||
}
|
||||
|
||||
export interface ImageGenerationRequest {
|
||||
model: string;
|
||||
prompt: string;
|
||||
n?: number;
|
||||
size?: string;
|
||||
}
|
||||
|
||||
export interface ImageGenerationResponse {
|
||||
created: number;
|
||||
data: Array<{
|
||||
url?: string;
|
||||
b64_json?: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
export interface AudioTranscriptionRequest {
|
||||
file: File;
|
||||
model: string;
|
||||
}
|
||||
|
||||
export interface AudioTranscriptionResponse {
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface SpeechGenerationRequest {
|
||||
model: string;
|
||||
input: string;
|
||||
voice: string;
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
import "./index.css";
|
||||
import "highlight.js/styles/github-dark.css";
|
||||
import App from "./App.svelte";
|
||||
import { mount } from "svelte";
|
||||
|
||||
const app = mount(App, {
|
||||
target: document.getElementById("app")!,
|
||||
});
|
||||
|
||||
export default app;
|
||||
@@ -0,0 +1,125 @@
|
||||
<script lang="ts">
|
||||
import { metrics, getCapture } from "../stores/api";
|
||||
import Tooltip from "../components/Tooltip.svelte";
|
||||
import CaptureDialog from "../components/CaptureDialog.svelte";
|
||||
import type { ReqRespCapture } from "../lib/types";
|
||||
|
||||
function formatSpeed(speed: number): string {
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
}
|
||||
|
||||
function formatDuration(ms: number): string {
|
||||
return (ms / 1000).toFixed(2) + "s";
|
||||
}
|
||||
|
||||
function formatRelativeTime(timestamp: string): string {
|
||||
const now = new Date();
|
||||
const date = new Date(timestamp);
|
||||
const diffInSeconds = Math.floor((now.getTime() - date.getTime()) / 1000);
|
||||
|
||||
// Handle future dates by returning "just now"
|
||||
if (diffInSeconds < 5) {
|
||||
return "now";
|
||||
}
|
||||
|
||||
if (diffInSeconds < 60) {
|
||||
return `${diffInSeconds}s ago`;
|
||||
}
|
||||
|
||||
const diffInMinutes = Math.floor(diffInSeconds / 60);
|
||||
if (diffInMinutes < 60) {
|
||||
return `${diffInMinutes}m ago`;
|
||||
}
|
||||
|
||||
const diffInHours = Math.floor(diffInMinutes / 60);
|
||||
if (diffInHours < 24) {
|
||||
return `${diffInHours}h ago`;
|
||||
}
|
||||
|
||||
return "a while ago";
|
||||
}
|
||||
|
||||
let sortedMetrics = $derived([...$metrics].sort((a, b) => b.id - a.id));
|
||||
|
||||
let selectedCapture = $state<ReqRespCapture | null>(null);
|
||||
let dialogOpen = $state(false);
|
||||
let loadingCaptureId = $state<number | null>(null);
|
||||
|
||||
async function viewCapture(id: number) {
|
||||
loadingCaptureId = id;
|
||||
const capture = await getCapture(id);
|
||||
loadingCaptureId = null;
|
||||
if (capture) {
|
||||
selectedCapture = capture;
|
||||
dialogOpen = true;
|
||||
}
|
||||
}
|
||||
|
||||
function closeDialog() {
|
||||
dialogOpen = false;
|
||||
selectedCapture = null;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="p-2">
|
||||
<h1 class="text-2xl font-bold">Activity</h1>
|
||||
|
||||
{#if $metrics.length === 0}
|
||||
<div class="text-center py-8">
|
||||
<p class="text-gray-600">No metrics data available</p>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="card overflow-auto">
|
||||
<table class="min-w-full divide-y">
|
||||
<thead class="border-gray-200 dark:border-white/10">
|
||||
<tr class="text-left text-xs uppercase tracking-wider">
|
||||
<th class="px-6 py-3">ID</th>
|
||||
<th class="px-6 py-3">Time</th>
|
||||
<th class="px-6 py-3">Model</th>
|
||||
<th class="px-6 py-3">
|
||||
Cached <Tooltip content="prompt tokens from cache" />
|
||||
</th>
|
||||
<th class="px-6 py-3">
|
||||
Prompt <Tooltip content="new prompt tokens processed" />
|
||||
</th>
|
||||
<th class="px-6 py-3">Generated</th>
|
||||
<th class="px-6 py-3">Prompt Processing</th>
|
||||
<th class="px-6 py-3">Generation Speed</th>
|
||||
<th class="px-6 py-3">Duration</th>
|
||||
<th class="px-6 py-3">Capture</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y">
|
||||
{#each sortedMetrics as metric (metric.id)}
|
||||
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
||||
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||
<td class="px-6 py-4">{metric.model}</td>
|
||||
<td class="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
|
||||
<td class="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
|
||||
<td class="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
|
||||
<td class="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
|
||||
<td class="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
|
||||
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||
<td class="px-6 py-4">
|
||||
{#if metric.has_capture}
|
||||
<button
|
||||
onclick={() => viewCapture(metric.id)}
|
||||
disabled={loadingCaptureId === metric.id}
|
||||
class="btn btn--sm"
|
||||
>
|
||||
{loadingCaptureId === metric.id ? "..." : "View"}
|
||||
</button>
|
||||
{:else}
|
||||
<span class="text-txtsecondary">-</span>
|
||||
{/if}
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
|
||||
@@ -0,0 +1,75 @@
|
||||
<script lang="ts">
|
||||
import { proxyLogs, upstreamLogs } from "../stores/api";
|
||||
import { screenWidth } from "../stores/theme";
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import LogPanel from "../components/LogPanel.svelte";
|
||||
import ResizablePanels from "../components/ResizablePanels.svelte";
|
||||
|
||||
type ViewMode = "proxy" | "upstream" | "panels";
|
||||
|
||||
const viewModeStore = persistentStore<ViewMode>("logviewer-view-mode", "panels");
|
||||
|
||||
let direction = $derived<"horizontal" | "vertical">(
|
||||
$screenWidth === "xs" || $screenWidth === "sm" ? "vertical" : "horizontal"
|
||||
);
|
||||
|
||||
function cycleViewMode(): void {
|
||||
const modes: ViewMode[] = ["panels", "proxy", "upstream"];
|
||||
const currentIndex = modes.indexOf($viewModeStore);
|
||||
const nextIndex = (currentIndex + 1) % modes.length;
|
||||
viewModeStore.set(modes[nextIndex]);
|
||||
}
|
||||
|
||||
function getViewModeIcon(mode: ViewMode): string {
|
||||
switch (mode) {
|
||||
case "proxy":
|
||||
return "P";
|
||||
case "upstream":
|
||||
return "U";
|
||||
case "panels":
|
||||
return "⊞";
|
||||
}
|
||||
}
|
||||
|
||||
function getViewModeLabel(mode: ViewMode): string {
|
||||
switch (mode) {
|
||||
case "proxy":
|
||||
return "Proxy";
|
||||
case "upstream":
|
||||
return "Upstream";
|
||||
case "panels":
|
||||
return "Panels";
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full w-full gap-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<button
|
||||
onclick={cycleViewMode}
|
||||
class="btn flex items-center gap-2 text-sm"
|
||||
title="Toggle view mode"
|
||||
aria-label="Toggle view mode: {getViewModeLabel($viewModeStore)}"
|
||||
>
|
||||
<span class="font-mono font-bold">{getViewModeIcon($viewModeStore)}</span>
|
||||
<span>{getViewModeLabel($viewModeStore)}</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="flex-1 w-full overflow-hidden">
|
||||
{#if $viewModeStore === "panels"}
|
||||
<ResizablePanels {direction} storageKey="logviewer-panel-group">
|
||||
{#snippet leftPanel()}
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={$proxyLogs} />
|
||||
{/snippet}
|
||||
{#snippet rightPanel()}
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={$upstreamLogs} />
|
||||
{/snippet}
|
||||
</ResizablePanels>
|
||||
{:else if $viewModeStore === "proxy"}
|
||||
<LogPanel id="proxy" title="Proxy Logs" logData={$proxyLogs} />
|
||||
{:else}
|
||||
<LogPanel id="upstream" title="Upstream Logs" logData={$upstreamLogs} />
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,26 @@
|
||||
<script lang="ts">
|
||||
import { isNarrow } from "../stores/theme";
|
||||
import { upstreamLogs } from "../stores/api";
|
||||
import ModelsPanel from "../components/ModelsPanel.svelte";
|
||||
import StatsPanel from "../components/StatsPanel.svelte";
|
||||
import LogPanel from "../components/LogPanel.svelte";
|
||||
import ResizablePanels from "../components/ResizablePanels.svelte";
|
||||
|
||||
let direction = $derived<"horizontal" | "vertical">($isNarrow ? "vertical" : "horizontal");
|
||||
</script>
|
||||
|
||||
<ResizablePanels {direction} storageKey="models-panel-group">
|
||||
{#snippet leftPanel()}
|
||||
<ModelsPanel />
|
||||
{/snippet}
|
||||
{#snippet rightPanel()}
|
||||
<div class="flex flex-col h-full space-y-4">
|
||||
{#if direction === "horizontal"}
|
||||
<StatsPanel />
|
||||
{/if}
|
||||
<div class="flex-1 min-h-0">
|
||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
|
||||
</div>
|
||||
</div>
|
||||
{/snippet}
|
||||
</ResizablePanels>
|
||||
@@ -0,0 +1,99 @@
|
||||
<script lang="ts">
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import ChatInterface from "../components/playground/ChatInterface.svelte";
|
||||
import ImageInterface from "../components/playground/ImageInterface.svelte";
|
||||
import AudioInterface from "../components/playground/AudioInterface.svelte";
|
||||
import SpeechInterface from "../components/playground/SpeechInterface.svelte";
|
||||
|
||||
type Tab = "chat" | "images" | "speech" | "audio";
|
||||
|
||||
const selectedTabStore = persistentStore<Tab>("playground-selected-tab", "chat");
|
||||
let mobileMenuOpen = $state(false);
|
||||
|
||||
const tabs: { id: Tab; label: string }[] = [
|
||||
{ id: "chat", label: "Chat" },
|
||||
{ id: "images", label: "Images" },
|
||||
{ id: "speech", label: "Speech" },
|
||||
{ id: "audio", label: "Transcription" },
|
||||
];
|
||||
|
||||
function selectTab(tab: Tab) {
|
||||
selectedTabStore.set(tab);
|
||||
mobileMenuOpen = false;
|
||||
}
|
||||
|
||||
function getTabLabel(tabId: Tab): string {
|
||||
return tabs.find(t => t.id === tabId)?.label || "";
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="card h-full flex flex-col">
|
||||
<!-- Tab navigation -->
|
||||
<div class="shrink-0 mb-4">
|
||||
<!-- Mobile: Dropdown menu (hidden on md and up) -->
|
||||
<div class="block md:hidden relative">
|
||||
<button
|
||||
class="w-full px-4 py-2 rounded font-medium transition-colors flex items-center justify-between bg-surface hover:bg-secondary-hover border border-gray-200 dark:border-white/10"
|
||||
onclick={() => (mobileMenuOpen = !mobileMenuOpen)}
|
||||
>
|
||||
<span>{getTabLabel($selectedTabStore)}</span>
|
||||
<svg
|
||||
class="w-5 h-5 transition-transform {mobileMenuOpen ? 'rotate-180' : ''}"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7"></path>
|
||||
</svg>
|
||||
</button>
|
||||
{#if mobileMenuOpen}
|
||||
<div class="absolute top-full left-0 right-0 mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10">
|
||||
{#each tabs as tab (tab.id)}
|
||||
<button
|
||||
class="w-full px-4 py-2 text-left hover:bg-secondary-hover transition-colors first:rounded-t last:rounded-b {$selectedTabStore === tab.id ? 'bg-primary/10 font-medium' : ''}"
|
||||
onclick={() => selectTab(tab.id)}
|
||||
>
|
||||
{tab.label}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Desktop: Tab buttons (shown on md and up) -->
|
||||
<div class="hidden md:flex flex-wrap gap-2">
|
||||
{#each tabs as tab (tab.id)}
|
||||
<button
|
||||
class="px-4 py-2 rounded font-medium transition-colors {$selectedTabStore === tab.id
|
||||
? 'bg-primary text-btn-primary-text'
|
||||
: 'bg-surface hover:bg-secondary-hover border border-gray-200 dark:border-white/10'}"
|
||||
onclick={() => selectTab(tab.id)}
|
||||
>
|
||||
{tab.label}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Tab content -->
|
||||
<div class="flex-1 overflow-hidden relative">
|
||||
<div class="h-full" class:tab-hidden={$selectedTabStore !== "chat"}>
|
||||
<ChatInterface />
|
||||
</div>
|
||||
<div class="h-full" class:tab-hidden={$selectedTabStore !== "images"}>
|
||||
<ImageInterface />
|
||||
</div>
|
||||
<div class="h-full" class:tab-hidden={$selectedTabStore !== "speech"}>
|
||||
<SpeechInterface />
|
||||
</div>
|
||||
<div class="h-full" class:tab-hidden={$selectedTabStore !== "audio"}>
|
||||
<AudioInterface />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.tab-hidden {
|
||||
display: none;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1 @@
|
||||
<!-- empty: real Playground is always mounted in App.svelte -->
|
||||
@@ -0,0 +1,198 @@
|
||||
import { writable } from "svelte/store";
|
||||
import type { Model, Metrics, VersionInfo, LogData, APIEventEnvelope, ReqRespCapture, InFlightStats } from "../lib/types";
|
||||
import { connectionState } from "./theme";
|
||||
|
||||
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||
|
||||
// Stores
|
||||
export const models = writable<Model[]>([]);
|
||||
export const proxyLogs = writable<string>("");
|
||||
export const upstreamLogs = writable<string>("");
|
||||
export const metrics = writable<Metrics[]>([]);
|
||||
export const inFlightRequests = writable<number>(0);
|
||||
export const versionInfo = writable<VersionInfo>({
|
||||
build_date: "unknown",
|
||||
commit: "unknown",
|
||||
version: "unknown",
|
||||
});
|
||||
|
||||
let apiEventSource: EventSource | null = null;
|
||||
|
||||
function appendLog(newData: string, store: typeof proxyLogs | typeof upstreamLogs): void {
|
||||
store.update((prev) => {
|
||||
const updatedLog = prev + newData;
|
||||
return updatedLog.length > LOG_LENGTH_LIMIT ? updatedLog.slice(-LOG_LENGTH_LIMIT) : updatedLog;
|
||||
});
|
||||
}
|
||||
|
||||
export function enableAPIEvents(enabled: boolean): void {
|
||||
if (!enabled) {
|
||||
apiEventSource?.close();
|
||||
apiEventSource = null;
|
||||
metrics.set([]);
|
||||
inFlightRequests.set(0);
|
||||
return;
|
||||
}
|
||||
|
||||
let retryCount = 0;
|
||||
const initialDelay = 1000; // 1 second
|
||||
|
||||
const connect = () => {
|
||||
apiEventSource?.close();
|
||||
apiEventSource = new EventSource("/api/events");
|
||||
|
||||
connectionState.set("connecting");
|
||||
|
||||
apiEventSource.onopen = () => {
|
||||
// Clear everything on connect to keep things in sync
|
||||
proxyLogs.set("");
|
||||
upstreamLogs.set("");
|
||||
metrics.set([]);
|
||||
inFlightRequests.set(0);
|
||||
models.set([]);
|
||||
retryCount = 0;
|
||||
connectionState.set("connected");
|
||||
};
|
||||
|
||||
apiEventSource.onmessage = (e: MessageEvent) => {
|
||||
try {
|
||||
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||
switch (message.type) {
|
||||
case "modelStatus": {
|
||||
const newModels = JSON.parse(message.data) as Model[];
|
||||
// Sort models by name and id
|
||||
newModels.sort((a, b) => {
|
||||
return (a.name + a.id).localeCompare(b.name + b.id);
|
||||
});
|
||||
models.set(newModels);
|
||||
break;
|
||||
}
|
||||
|
||||
case "logData": {
|
||||
const logData = JSON.parse(message.data) as LogData;
|
||||
switch (logData.source) {
|
||||
case "proxy":
|
||||
appendLog(logData.data, proxyLogs);
|
||||
break;
|
||||
case "upstream":
|
||||
appendLog(logData.data, upstreamLogs);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "metrics": {
|
||||
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||
metrics.update((prevMetrics) => [...newMetrics, ...prevMetrics]);
|
||||
break;
|
||||
}
|
||||
case "inflight": {
|
||||
const stats = JSON.parse(message.data) as InFlightStats;
|
||||
inFlightRequests.set(stats.total ?? 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(e.data, err);
|
||||
}
|
||||
};
|
||||
|
||||
apiEventSource.onerror = () => {
|
||||
apiEventSource?.close();
|
||||
retryCount++;
|
||||
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||
connectionState.set("disconnected");
|
||||
setTimeout(connect, delay);
|
||||
};
|
||||
};
|
||||
|
||||
connect();
|
||||
}
|
||||
|
||||
// Fetch version info when connected
|
||||
connectionState.subscribe(async (status) => {
|
||||
if (status === "connected") {
|
||||
try {
|
||||
const response = await fetch("/api/version");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data: VersionInfo = await response.json();
|
||||
versionInfo.set(data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
export async function listModels(): Promise<Model[]> {
|
||||
try {
|
||||
const response = await fetch("/api/models/");
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const data = await response.json();
|
||||
return data || [];
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch models:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export async function unloadAllModels(): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to unload models: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to unload models:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function unloadSingleModel(model: string): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`/api/models/unload/${model}`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to unload model: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to unload model", model, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function loadModel(model: string): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`/upstream/${model}/`, {
|
||||
method: "GET",
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load model: ${response.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load model:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCapture(id: number): Promise<ReqRespCapture | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/captures/${id}`);
|
||||
if (response.status === 404) {
|
||||
return null;
|
||||
}
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch capture: ${response.status}`);
|
||||
}
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch capture:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import { writable, type Writable } from "svelte/store";
|
||||
|
||||
export function persistentStore<T>(key: string, initialValue: T): Writable<T> {
|
||||
// Get initial value from localStorage or use default
|
||||
let storedValue = initialValue;
|
||||
if (typeof window !== "undefined") {
|
||||
try {
|
||||
const saved = localStorage.getItem(key);
|
||||
if (saved !== null) {
|
||||
storedValue = JSON.parse(saved);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(`Error parsing stored value for ${key}`, e);
|
||||
}
|
||||
}
|
||||
|
||||
const store = writable<T>(storedValue);
|
||||
|
||||
// Subscribe to changes and save to localStorage
|
||||
store.subscribe((value) => {
|
||||
if (typeof window !== "undefined") {
|
||||
try {
|
||||
localStorage.setItem(key, JSON.stringify(value));
|
||||
} catch (e) {
|
||||
console.error(`Error saving value for ${key}`, e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return store;
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
import { writable, derived } from "svelte/store";
|
||||
|
||||
const chatStreaming = writable(false);
|
||||
const imageGenerating = writable(false);
|
||||
const speechGenerating = writable(false);
|
||||
const audioTranscribing = writable(false);
|
||||
|
||||
export const playgroundActivity = derived(
|
||||
[chatStreaming, imageGenerating, speechGenerating, audioTranscribing],
|
||||
([$chat, $image, $speech, $audio]) => $chat || $image || $speech || $audio
|
||||
);
|
||||
|
||||
export const playgroundStores = {
|
||||
chatStreaming,
|
||||
imageGenerating,
|
||||
speechGenerating,
|
||||
audioTranscribing,
|
||||
};
|
||||