Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a9d840ffd7 | |||
| 7b2b82777f | |||
| d87f0ce2c5 | |||
| 06bc6a614c | |||
| a37b4866d8 | |||
| 981910d734 | |||
| a185efe37e | |||
| 1dd1aadf93 | |||
| 955900972a | |||
| c2c8cfaf81 | |||
| 1e440770ea | |||
| c794273c83 | |||
| 6574a52cbb | |||
| 8fabc75634 | |||
| e5e7391b6d | |||
| 2c282dccad | |||
| 916d13f5bd | |||
| a3725e7d09 | |||
| 15bd55d3a9 | |||
| c3c258a55d | |||
| 29a38fde0d | |||
| d569681daa | |||
| 24efdb76b1 | |||
| cc77139ff8 | |||
| 390a35bf93 | |||
| 181f71ca11 | |||
| 49546e2cf2 | |||
| 2c078964f4 | |||
| 175bb36fb1 | |||
| aedb640471 | |||
| 2f377f6dc6 | |||
| 64e4c79fc3 | |||
| 19fb5f35e9 | |||
| b45102bde8 | |||
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e | |||
| 86e9b93c37 | |||
| 3acace810f | |||
| 554d29e87d | |||
| 3567b7df08 | |||
| 38738525c9 | |||
| c0fc858193 | |||
| b429349e8a | |||
| eab2efd7b5 | |||
| 6aedbe121a | |||
| b24467ab89 | |||
| 12b69fb718 | |||
| f91a8b2462 |
@@ -4,12 +4,19 @@ early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
name: Validate JSON Schema
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
validate-schema:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Validate JSON Schema
|
||||
run: |
|
||||
# Check if the file is valid JSON
|
||||
if ! jq empty config-schema.json 2>/dev/null; then
|
||||
echo "Error: config-schema.json is not valid JSON"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate that it's a valid JSON Schema
|
||||
# Check for required $schema field
|
||||
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
|
||||
echo "Warning: config-schema.json should have a \$schema field"
|
||||
fi
|
||||
|
||||
# Check that it has either properties or definitions
|
||||
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
|
||||
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
|
||||
fi
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
@@ -10,17 +10,44 @@ on:
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
- 'docker/*.Containerfile'
|
||||
|
||||
# grant permissions on GITHUB_TOKEN to publish packages
|
||||
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -31,7 +58,7 @@ jobs:
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Windows CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
@@ -28,7 +44,7 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
@@ -43,7 +59,7 @@ jobs:
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Linux CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
@@ -20,7 +36,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version-file: go.mod
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
@@ -35,7 +51,7 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
@@ -51,4 +67,4 @@ jobs:
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
run: make test-all
|
||||
|
||||
@@ -3,13 +3,13 @@ name: goreleaser
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
description: "Tag version to release (e.g. v144)"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
@@ -19,35 +19,30 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Set up Node.js
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||
distribution: goreleaser
|
||||
# 'latest', 'nightly', or a semver
|
||||
version: '~> v2'
|
||||
version: "~> v2"
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -76,4 +71,4 @@ jobs:
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
name: UI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ui-svelte
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ui-svelte/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Type check
|
||||
run: npm run check
|
||||
|
||||
- name: Run tests
|
||||
run: npm test
|
||||
@@ -0,0 +1,131 @@
|
||||
name: Build Unified Docker Image
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_cpp_ref:
|
||||
description: "llama.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
whisper_ref:
|
||||
description: "whisper.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
sd_ref:
|
||||
description: "stable-diffusion.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
ik_llama_ref:
|
||||
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
|
||||
required: false
|
||||
default: "main"
|
||||
llama_swap_version:
|
||||
description: "llama-swap version (e.g. v198, latest, main)"
|
||||
required: false
|
||||
default: "main"
|
||||
build_cuda:
|
||||
description: "Build CUDA image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
build_vulkan:
|
||||
description: "Build Vulkan image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
run: |
|
||||
backends=()
|
||||
# schedule uses defaults (build both); workflow_dispatch respects inputs
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
|
||||
backends+=("cuda")
|
||||
fi
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
|
||||
backends+=("vulkan")
|
||||
fi
|
||||
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
|
||||
echo "matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
|
||||
build:
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.matrix != '[]' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
# On GitHub Actions runners, create a fresh builder.
|
||||
# When running locally under act, skip this and reuse the existing
|
||||
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
|
||||
- name: Set up Docker Buildx
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build unified Docker image (${{ matrix.backend }})
|
||||
env:
|
||||
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
|
||||
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
|
||||
SD_REF: ${{ inputs.sd_ref || 'master' }}
|
||||
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
|
||||
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
|
||||
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
|
||||
# When running under act, use the local builder that has warm ccache.
|
||||
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
|
||||
# created by setup-buildx-action above.
|
||||
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
|
||||
run: |
|
||||
chmod +x docker/unified/build-image.sh
|
||||
docker/unified/build-image.sh --${{ matrix.backend }}
|
||||
|
||||
- name: Push to GitHub Container Registry
|
||||
if: ${{ !env.ACT }}
|
||||
run: |
|
||||
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
|
||||
DATE_TAG=$(date -u +%Y-%m-%d)
|
||||
|
||||
docker push "${BASE_TAG}"
|
||||
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
|
||||
docker push "${BASE_TAG}-${DATE_TAG}"
|
||||
|
||||
ROOTLESS_TAG="${BASE_TAG}-rootless"
|
||||
docker push "${ROOTLESS_TAG}"
|
||||
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||
@@ -0,0 +1,51 @@
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
fixes #123
|
||||
```
|
||||
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
@@ -1,43 +1 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors.
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
|
||||
### Implementation of plans
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
|
||||
## General Rules
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
@AGENTS.md
|
||||
|
||||
@@ -36,11 +36,11 @@ test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
cd ui-svelte && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui && npm run build
|
||||
cd ui-svelte && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac: ui
|
||||
@@ -51,7 +51,7 @@ mac: ui
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
||||
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
||||
|
||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||
|
||||
@@ -13,18 +13,29 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
- `/completion` - for completion endpoint
|
||||
- ✅ SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
|
||||
- `/sdapi/v1/txt2img`
|
||||
- `/sdapi/v1/img2img`
|
||||
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
|
||||
- ✅ llama-swap API
|
||||
- `/ui` - web UI
|
||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
@@ -32,6 +43,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
@@ -40,13 +52,27 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
|
||||
### Web UI
|
||||
|
||||
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
||||
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
|
||||
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
||||
|
||||
The Activity Page shows recent requests:
|
||||
View detailed token metrics:
|
||||
|
||||
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
|
||||
|
||||
Inspect request and responses:
|
||||
|
||||
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
||||
|
||||
Manually load and unload models:
|
||||
|
||||
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
||||
|
||||
|
||||
Real time log streaming:
|
||||
|
||||
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -60,7 +86,8 @@ llama-swap can be installed in multiple ways
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||
|
||||
```shell
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
@@ -70,6 +97,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# configuration hot reload supported with a
|
||||
# directory volume mount
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
-v /path/to/config:/config \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||
```
|
||||
|
||||
<details>
|
||||
@@ -88,6 +123,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
# non-root cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -190,23 +228,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```shell
|
||||
```sh
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
$ curl http://host/logs
|
||||
|
||||
# streams combined logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
curl -Ns http://host/logs/stream
|
||||
|
||||
# just llama-swap's logs
|
||||
curl -Ns 'http://host/logs/stream/proxy'
|
||||
# stream llama-swap's proxy status logs
|
||||
curl -Ns http://host/logs/stream/proxy
|
||||
|
||||
# just upstream's logs
|
||||
curl -Ns 'http://host/logs/stream/upstream'
|
||||
# stream logs from upstream processes that llama-swap loads
|
||||
curl -Ns http://host/logs/stream/upstream
|
||||
|
||||
# stream logs only from a specific model
|
||||
curl -Ns http://host/logs/stream/{model_id}
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
# appending ?no-history will disable sending buffered history first
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||
|
||||
## Current Issues
|
||||
|
||||
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||
4. Linked list structure has poor cache locality and pointer overhead
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### New CircularBuffer Type
|
||||
|
||||
Create a simple circular byte buffer with:
|
||||
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||
- `head` and `size` integers to track write position and data length
|
||||
- No per-write allocations
|
||||
|
||||
### API Requirements
|
||||
|
||||
The new buffer must support:
|
||||
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||
|
||||
### Implementation Details
|
||||
|
||||
```go
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
```
|
||||
|
||||
**Write logic:**
|
||||
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||
- Update `head` and `size` accordingly
|
||||
- Data is copied into the internal buffer (not stored by reference)
|
||||
|
||||
**GetHistory logic:**
|
||||
- Calculate start position: `(head - size + cap) % cap`
|
||||
- If not wrapped: single slice copy
|
||||
- If wrapped: two copies (end of buffer + beginning)
|
||||
- Returns a **new slice** (copy), not a view into internal buffer
|
||||
|
||||
### Immutability Guarantees (must preserve)
|
||||
|
||||
Per existing tests:
|
||||
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||
|
||||
## Files to Modify
|
||||
|
||||
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||
|
||||
## Testing Plan
|
||||
|
||||
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||
|
||||
Add new tests:
|
||||
- Test buffer wrap-around behavior
|
||||
- Test large writes that exceed buffer capacity
|
||||
- Test exact capacity boundary conditions
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||
- [ ] Implement `Write()` method for circular buffer
|
||||
- [ ] Implement `GetHistory()` method for circular buffer
|
||||
- [ ] Update `LogMonitor` struct to use new buffer
|
||||
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||
- [ ] Remove `"container/ring"` import
|
||||
- [ ] Run `make test-dev` to verify existing tests pass
|
||||
- [ ] Add wrap-around test case
|
||||
- [ ] Run `make test-all` for final validation
|
||||
@@ -210,6 +210,11 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||
model := c.Query("model")
|
||||
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
@@ -269,6 +274,43 @@ func main() {
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
// SD API endpoints
|
||||
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Loading...</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.loader {
|
||||
text-align: center;
|
||||
}
|
||||
.stats {
|
||||
font-size: 18px;
|
||||
color: #333;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.stats-label {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="loader">
|
||||
<p>Waking up upstream server...</p>
|
||||
<div class="stats">
|
||||
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
|
||||
<div><span id="attempts"> </span></div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
var startTime = Date.now();
|
||||
var attempts = 0;
|
||||
|
||||
setInterval(function() {
|
||||
var elapsed = (Date.now() - startTime) / 1000;
|
||||
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
|
||||
}, 100);
|
||||
|
||||
// Check status every second
|
||||
setInterval(function() {
|
||||
attempts++;
|
||||
var dots = '.'.repeat((attempts % 10) || 10);
|
||||
document.getElementById('attempts').textContent = dots;
|
||||
|
||||
fetch('/status')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(t) {
|
||||
if (t.indexOf('status: ready') !== -1) {
|
||||
location.reload();
|
||||
}
|
||||
})
|
||||
.catch(function() {});
|
||||
}, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -19,6 +20,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed index.html
|
||||
var loadingPageHTML string
|
||||
|
||||
var (
|
||||
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
|
||||
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
|
||||
@@ -230,6 +234,16 @@ func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if err := sendMagicPacket(*flagMac); err != nil {
|
||||
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||
}
|
||||
|
||||
// For root or UI path requests, return loading page with status polling
|
||||
// the web page will do the polling and redirect when ready
|
||||
if path == "/" || strings.HasPrefix(path, "/ui/") {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, loadingPageHTML)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -0,0 +1,460 @@
|
||||
{
|
||||
"$schema": "https://json-schema.org/draft-07/schema#",
|
||||
"$id": "llama-swap-config-schema.json",
|
||||
"title": "llama-swap configuration",
|
||||
"description": "Configuration file for llama-swap",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"models"
|
||||
],
|
||||
"definitions": {
|
||||
"macros": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 0,
|
||||
"maxLength": 1024
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"propertyNames": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[a-zA-Z0-9_-]+$",
|
||||
"not": {
|
||||
"enum": [
|
||||
"PORT",
|
||||
"MODEL_ID"
|
||||
]
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"keepalive": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Time to wait for response headers in seconds. Set to 0 to disable."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"expectContinue": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 1,
|
||||
"description": "Expect-Continue timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds. Set to 0 to disable."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"healthCheckTimeout": {
|
||||
"type": "integer",
|
||||
"minimum": 15,
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"globalTTL": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"debug",
|
||||
"info",
|
||||
"warn",
|
||||
"error"
|
||||
],
|
||||
"default": "info",
|
||||
"description": "Sets the logging value. Valid values: debug, info, warn, error."
|
||||
},
|
||||
"logTimeFormat": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"",
|
||||
"ansic",
|
||||
"unixdate",
|
||||
"rubydate",
|
||||
"rfc822",
|
||||
"rfc822z",
|
||||
"rfc850",
|
||||
"rfc1123",
|
||||
"rfc1123z",
|
||||
"rfc3339",
|
||||
"rfc3339nano",
|
||||
"kitchen",
|
||||
"stamp",
|
||||
"stampmilli",
|
||||
"stampmicro",
|
||||
"stampnano"
|
||||
],
|
||||
"default": "",
|
||||
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
|
||||
},
|
||||
"metricsMaxInMemory": {
|
||||
"type": "integer",
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"captureBuffer": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
|
||||
},
|
||||
"includeAliasesInList": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
|
||||
},
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"models": {
|
||||
"type": "object",
|
||||
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"cmd"
|
||||
],
|
||||
"properties": {
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"cmd": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
|
||||
},
|
||||
"cmdStop": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 128,
|
||||
"description": "Display name for the model. Used in v1/models API response."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 1024,
|
||||
"description": "Description for the model. Used in v1/models API response."
|
||||
},
|
||||
"env": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
|
||||
},
|
||||
"default": [],
|
||||
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
|
||||
},
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"default": "http://localhost:${PORT}",
|
||||
"format": "uri",
|
||||
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
|
||||
},
|
||||
"aliases": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Alternative model names for this configuration. Must be unique globally."
|
||||
},
|
||||
"checkEndpoint": {
|
||||
"type": "string",
|
||||
"default": "/health",
|
||||
"pattern": "^/.*$|^none$",
|
||||
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": -1,
|
||||
"default": -1,
|
||||
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
},
|
||||
"setParamsByID": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"default": {},
|
||||
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
|
||||
},
|
||||
"concurrencyLimit": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
|
||||
},
|
||||
"unlisted": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"on_startup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"preload": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [],
|
||||
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Actions to perform on startup. Only supported action is preload."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds."
|
||||
},
|
||||
"keepalive": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP keepalive connection timeout in seconds."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Time to wait for response headers in seconds."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections to this peer."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
@@ -23,12 +26,35 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
@@ -43,6 +69,17 @@ startPort: 10001
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# globalTTL: the default TTL in seconds before unloading a model
|
||||
# - optional, default: 0 (never automatically unload)
|
||||
# - must be >= 0
|
||||
globalTTL: 0
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -53,6 +90,9 @@ sendLoadingState: true
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
@@ -65,6 +105,24 @@ macros:
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -72,9 +130,8 @@ macros:
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
"gpt-oss-120b":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
@@ -91,7 +148,7 @@ models:
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--model path/to/gpt-oss-120B.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
@@ -99,13 +156,13 @@ models:
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
name: "gpt-oss 120B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
description: "A thinking model from OpenAI"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
@@ -120,14 +177,6 @@ models:
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
@@ -136,8 +185,10 @@ models:
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - optional, default: -1 (use global default)
|
||||
# - ttl values must be a value greater than or equal to 0
|
||||
# - a ttl of -1 will use the global TTL value as the default
|
||||
# - a ttl of 0 will mean never unload
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
@@ -145,11 +196,11 @@ models:
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "qwen:qwq"
|
||||
useModelName: "openai/gpt-oss-120B"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
@@ -159,6 +210,43 @@ models:
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - always runs for the model
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||
# - optional, default: empty dictionary
|
||||
# - combine with aliases to create variant behaviour without reloading the model
|
||||
# - parameters are set in the request body JSON
|
||||
# - run after setParams so it will override any settings
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - model aliases will be automatically created for each key
|
||||
setParamsByID:
|
||||
"${MODEL_ID}":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: medium
|
||||
"${MODEL_ID}:high":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: low
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
@@ -196,6 +284,22 @@ models:
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models running on slower hardware that need longer timeouts
|
||||
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
|
||||
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
|
||||
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
|
||||
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
|
||||
# - idleConn: idle connection timeout in seconds, default: 90 seconds
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 0
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -298,10 +402,74 @@ hooks:
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
@@ -1,55 +1,164 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the latest available
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
SD_TAG=master-${ARCH}
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
fi
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||
USER_UID=0
|
||||
USER_GID=0
|
||||
USER_HOME=/root
|
||||
|
||||
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||
USER_UID=10001
|
||||
USER_GID=10001
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||
case "$ARCH" in
|
||||
"musa" | "vulkan")
|
||||
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||
--build-arg BASE=${CONTAINER_TAG} \
|
||||
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||
esac
|
||||
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for llama-swap-docker with commit hash pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
|
||||
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
|
||||
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
|
||||
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
|
||||
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
|
||||
#
|
||||
# Features:
|
||||
# - Auto-detects latest commit hashes from git repos
|
||||
# - Builds llama-swap from local source code
|
||||
# - Allows environment variable overrides for reproducible builds
|
||||
# - Cache-friendly: changing commit hash busts cache appropriately
|
||||
# - Supports both CUDA and Vulkan backends (requires explicit flag)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Parse command line arguments
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate backend selection
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Configuration
|
||||
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
|
||||
# User provided a custom tag, use it as-is
|
||||
:
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
DOCKER_IMAGE_TAG="llama-swap:vulkan"
|
||||
else
|
||||
DOCKER_IMAGE_TAG="llama-swap:cuda"
|
||||
fi
|
||||
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
|
||||
|
||||
# Single unified Dockerfile, backend selected via build arg
|
||||
DOCKERFILE="Dockerfile"
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
|
||||
else
|
||||
echo "Building for: CUDA (NVIDIA GPUs)"
|
||||
fi
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
|
||||
# Function to get the latest commit hash from a git repo's default branch
|
||||
get_latest_commit() {
|
||||
local repo_url="$1"
|
||||
local branch="${2:-master}"
|
||||
|
||||
# Try to get the latest commit hash for the specified branch
|
||||
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
# Function to get the default branch name (master or main)
|
||||
get_default_branch() {
|
||||
local repo_url="$1"
|
||||
|
||||
# Check for master first
|
||||
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
|
||||
echo "master"
|
||||
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
|
||||
echo "main"
|
||||
else
|
||||
echo "master" # fallback
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to get the latest release tag from a GitHub repo
|
||||
get_latest_release_tag() {
|
||||
local owner_repo="$1"
|
||||
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap-docker Build Script"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Determine commit hashes / release tags - use env vars or auto-detect
|
||||
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
|
||||
# For cuda builds (or whisper on any backend), use git commit hashes.
|
||||
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
|
||||
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
|
||||
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
|
||||
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
|
||||
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
|
||||
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
|
||||
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
|
||||
SD_HASH="${SD_COMMIT_HASH}"
|
||||
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
|
||||
else
|
||||
SD_BRANCH=$(get_default_branch "${SD_REPO}")
|
||||
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Build the Docker image with commit hashes as build args
|
||||
# Build context is the repository root (..) so the Dockerfile can access Go source
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/${DOCKERFILE}"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
fi
|
||||
|
||||
# Use docker buildx with a custom builder for parallelism control
|
||||
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
|
||||
# We need to use a custom builder with a buildkitd.toml config file
|
||||
BUILDER_NAME="llama-swap-builder"
|
||||
|
||||
# Check if our custom builder exists with the right config, create/update if needed
|
||||
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
|
||||
echo "Creating custom buildx builder with max-parallelism=1..."
|
||||
|
||||
# Create buildkitd.toml config file
|
||||
cat > buildkitd.toml << 'BUILDKIT_EOF'
|
||||
[worker.oci]
|
||||
max-parallelism = 1
|
||||
BUILDKIT_EOF
|
||||
|
||||
# Create the builder with the config
|
||||
docker buildx create --name "$BUILDER_NAME" \
|
||||
--driver docker-container \
|
||||
--buildkitd-config buildkitd.toml \
|
||||
--use
|
||||
else
|
||||
# Switch to our builder
|
||||
docker buildx use "$BUILDER_NAME"
|
||||
fi
|
||||
|
||||
echo "Building with sequential stages (one at a time), each using all CPU cores..."
|
||||
echo "Using builder: $BUILDER_NAME"
|
||||
|
||||
# Use docker buildx build with --load to load the image into Docker
|
||||
# The --builder flag ensures we use our custom builder with max-parallelism=1
|
||||
# Build context is the repository root so we can access Go source files
|
||||
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Verify all expected binaries exist in the image
|
||||
MISSING_BINARIES=()
|
||||
|
||||
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
|
||||
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --vulkan --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tag: ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need to mount render devices:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -15,4 +15,19 @@ models:
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
--port 9999
|
||||
|
||||
z-image:
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||
@@ -0,0 +1,11 @@
|
||||
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||
ARG SD_TAG=master-vulkan
|
||||
ARG BASE=llama-swap:latest
|
||||
|
||||
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||
FROM ${BASE}
|
||||
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
|
||||
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||
@@ -1,16 +1,44 @@
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=89
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
# Set default UID/GID arguments
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
ARG USER_HOME=/app
|
||||
|
||||
# Add user/group
|
||||
ENV HOME=$USER_HOME
|
||||
RUN if [ $UID -ne 0 ]; then \
|
||||
if [ $GID -ne 0 ]; then \
|
||||
groupadd --system --gid $GID app; \
|
||||
fi; \
|
||||
useradd --system --uid $UID --gid $GID \
|
||||
--home $USER_HOME app; \
|
||||
fi
|
||||
|
||||
# Handle paths
|
||||
RUN mkdir --parents $HOME /app
|
||||
RUN chown --recursive $UID:$GID $HOME /app
|
||||
|
||||
# Switch user
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
|
||||
COPY config.example.yaml /app/config.yaml
|
||||
# Add /app to PATH
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
# Unified multi-stage Dockerfile for AI inference tools
|
||||
# Supports CUDA and Vulkan backends via BACKEND build arg
|
||||
#
|
||||
# Usage:
|
||||
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
|
||||
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
|
||||
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
|
||||
#
|
||||
# Each project has its own install script that handles cloning, building,
|
||||
# and installing binaries. Build stages are independent for cache efficiency.
|
||||
|
||||
ARG BACKEND=cuda
|
||||
|
||||
# ── Builder bases ──────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
|
||||
|
||||
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS builder-base-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget software-properties-common \
|
||||
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ── Select builder base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM builder-base-${BACKEND} AS builder-base
|
||||
|
||||
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
|
||||
|
||||
FROM builder-base AS whisper-build
|
||||
ARG BACKEND=cuda
|
||||
ARG WHISPER_COMMIT_HASH=master
|
||||
COPY install-whisper.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
|
||||
|
||||
# ── Build stable-diffusion.cpp ────────────────────────────────────────
|
||||
|
||||
FROM builder-base AS sd-build
|
||||
ARG BACKEND=cuda
|
||||
ARG SD_COMMIT_HASH=master
|
||||
COPY install-sd.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
|
||||
|
||||
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
|
||||
|
||||
FROM builder-base AS llama-build
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=master
|
||||
COPY install-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
|
||||
|
||||
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
|
||||
#
|
||||
# Two named stages allow ARG BACKEND to select at build time:
|
||||
# - ik-llama-cuda : real build (from builder-base-cuda)
|
||||
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
|
||||
# BuildKit only evaluates the selected branch, so vulkan builds never
|
||||
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
|
||||
|
||||
FROM builder-base-vulkan AS ik-llama-vulkan
|
||||
RUN mkdir -p /install/bin
|
||||
|
||||
FROM builder-base-cuda AS ik-llama-cuda
|
||||
ARG IK_LLAMA_COMMIT_HASH=main
|
||||
COPY install-ik-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
|
||||
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
|
||||
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
|
||||
|
||||
ARG BACKEND=cuda
|
||||
FROM ik-llama-${BACKEND} AS ik-llama-build
|
||||
|
||||
# ── Download llama-swap release binary ────────────────────────────────
|
||||
|
||||
FROM builder-base AS llama-swap-download
|
||||
ARG LS_VERSION=latest
|
||||
COPY install-llama-swap.sh /build/
|
||||
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
|
||||
|
||||
# ── Runtime bases ─────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# CUDA stub drivers for container compatibility
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS runtime-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 libvulkan1 mesa-vulkan-drivers \
|
||||
python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ── Select runtime base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM runtime-${BACKEND} AS runtime
|
||||
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=unknown
|
||||
ARG WHISPER_COMMIT_HASH=unknown
|
||||
ARG SD_COMMIT_HASH=unknown
|
||||
ARG IK_LLAMA_COMMIT_HASH=unknown
|
||||
ARG RUN_UID=0
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3-numpy python3-sentencepiece \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user when RUN_UID != 0
|
||||
RUN if [ "$RUN_UID" != "0" ]; then \
|
||||
groupadd --system --gid $RUN_UID llama-swap && \
|
||||
useradd --system --uid $RUN_UID --gid $RUN_UID \
|
||||
--home /app --shell /sbin/nologin llama-swap; \
|
||||
fi && \
|
||||
mkdir -p /etc/llama-swap/config && \
|
||||
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy whisper.cpp binaries and libraries
|
||||
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
|
||||
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
|
||||
COPY --from=whisper-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy stable-diffusion.cpp binaries and libraries
|
||||
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
|
||||
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
|
||||
COPY --from=sd-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy llama.cpp binaries (statically linked)
|
||||
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
|
||||
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
|
||||
|
||||
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
||||
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
||||
|
||||
# Copy llama-swap binary
|
||||
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
||||
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
||||
|
||||
RUN ldconfig
|
||||
|
||||
COPY config.example.yaml /etc/llama-swap/config/config.yaml
|
||||
|
||||
# Version tracking
|
||||
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
|
||||
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
|
||||
echo "backend: ${BACKEND}" >> /versions.txt && \
|
||||
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
|
||||
|
||||
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
|
||||
WORKDIR /models
|
||||
USER ${RUN_UID}
|
||||
ENTRYPOINT ["llama-swap"]
|
||||
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
|
||||
@@ -0,0 +1,8 @@
|
||||
# Unified Docker Container
|
||||
|
||||
These scripts create a custom llama-swap container that contains:
|
||||
|
||||
- llama-server for LLMs, rerank and embedding model support
|
||||
- sd-server (stable-diffusion.cpp) for image generation
|
||||
- whisper.cpp for ASR
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for unified container with version pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build without cache
|
||||
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
|
||||
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
|
||||
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
|
||||
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
|
||||
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
|
||||
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
|
||||
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
|
||||
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
|
||||
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
|
||||
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
|
||||
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
|
||||
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
|
||||
|
||||
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
|
||||
# Requires only: git, network access to the remote.
|
||||
resolve_ref() {
|
||||
local repo_url="$1"
|
||||
local ref="$2"
|
||||
|
||||
# Full 40-char SHA — use as-is
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "${ref}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Try tag then branch (exact match)
|
||||
local hash
|
||||
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Short hash (7+ chars): scan all refs for a SHA with this prefix
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
|
||||
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
|
||||
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
|
||||
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
|
||||
else
|
||||
echo " Tried: tag, branch, git ls-remote prefix match" >&2
|
||||
fi
|
||||
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Resolve HEAD of a repo without needing to know the default branch name.
|
||||
get_latest_hash() {
|
||||
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap Unified Build (${BACKEND})"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Resolve llama.cpp ref
|
||||
if [[ -n "${LLAMA_REF:-}" ]]; then
|
||||
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
|
||||
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve whisper.cpp ref
|
||||
if [[ -n "${WHISPER_REF:-}" ]]; then
|
||||
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
|
||||
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve stable-diffusion.cpp ref
|
||||
if [[ -n "${SD_REF:-}" ]]; then
|
||||
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
|
||||
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
|
||||
else
|
||||
SD_HASH=$(get_latest_hash "${SD_REPO}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve ik_llama.cpp ref (CUDA only)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
|
||||
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
|
||||
else
|
||||
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
|
||||
if [[ -z "${IK_LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
else
|
||||
IK_LLAMA_HASH="n/a"
|
||||
echo "ik_llama.cpp: skipped (vulkan build)"
|
||||
fi
|
||||
|
||||
# Resolve llama-swap ref
|
||||
if [[ -n "${LS_VERSION:-}" ]]; then
|
||||
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
|
||||
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
|
||||
else
|
||||
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
|
||||
if [[ -z "${LS_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama-swap" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama-swap: latest HEAD: ${LS_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
|
||||
--build-arg "LS_VERSION=${LS_HASH}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/Dockerfile"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
|
||||
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
|
||||
BUILD_ARGS+=(
|
||||
--cache-from "type=registry,ref=${CACHE_REF}"
|
||||
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
|
||||
)
|
||||
echo "Note: Using registry cache (${CACHE_REF})"
|
||||
fi
|
||||
|
||||
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
EXPECTED_BINARIES+=(ik-llama-server)
|
||||
fi
|
||||
|
||||
MISSING_BINARIES=()
|
||||
for binary in "${EXPECTED_BINARIES[@]}"; do
|
||||
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --${BACKEND} --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
|
||||
fi
|
||||
echo "All expected binaries verified: ${VERIFIED_LIST}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Building rootless image..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
|
||||
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
|
||||
FROM ${DOCKER_IMAGE_TAG}
|
||||
USER root
|
||||
RUN groupadd --system --gid 10001 llama-swap && \\
|
||||
useradd --system --uid 10001 --gid 10001 \\
|
||||
--home /app --shell /sbin/nologin llama-swap && \\
|
||||
chown -R 10001:10001 /etc/llama-swap /models
|
||||
USER 10001
|
||||
EOF
|
||||
|
||||
echo "Rootless image built: ${ROOTLESS_TAG}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tags:"
|
||||
echo " ${DOCKER_IMAGE_TAG}"
|
||||
echo " ${ROOTLESS_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -0,0 +1,33 @@
|
||||
# placeholder example configuration
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
|
||||
"whisper":
|
||||
checkEndpoint: /v1/audio/transcriptions/
|
||||
cmd: >
|
||||
whisper-server
|
||||
--port ${PORT}
|
||||
--m /models/whisper.bin
|
||||
--flash-attn
|
||||
--request-path /v1/audio/transcriptions --inference-path ""
|
||||
|
||||
"image":
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Install ik_llama.cpp - clone, build, and install binaries
|
||||
# Usage: ./install-ik-llama.sh <commit_hash>
|
||||
# Note: CUDA only; always built against builder-base-cuda
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-main}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
|
||||
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/ik_llama.cpp
|
||||
cd /src/ik_llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DGGML_CUDA=ON
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
|
||||
)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building ik_llama.cpp ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target llama-server
|
||||
|
||||
if [ ! -f "build/bin/llama-server" ]; then
|
||||
echo "FATAL: llama-server not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
|
||||
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
|
||||
echo "=== ik_llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
# Install llama-swap - download latest release binary from GitHub
|
||||
# Usage: ./install-llama-swap.sh [version]
|
||||
# version: release version number (e.g., "170") or "latest" (default)
|
||||
set -e
|
||||
|
||||
VERSION="${1:-latest}"
|
||||
REPO="mostlygeek/llama-swap"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# If a full commit hash is given, find the release tag that points to it
|
||||
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
|
||||
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
|
||||
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
|
||||
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
|
||||
if [ -n "${TAG}" ]; then
|
||||
echo "Resolved to tag: ${TAG}"
|
||||
VERSION="${TAG#v}"
|
||||
else
|
||||
echo "No release tag found for commit ${VERSION:0:7}, using latest"
|
||||
VERSION="latest"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Strip leading 'v' prefix so both "198" and "v198" work
|
||||
VERSION="${VERSION#v}"
|
||||
|
||||
# Resolve "latest" to actual version number
|
||||
if [ "$VERSION" = "latest" ]; then
|
||||
echo "=== Resolving latest llama-swap release ==="
|
||||
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
|
||||
if [ -z "$VERSION" ]; then
|
||||
echo "FATAL: Could not determine latest release version" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Latest version: ${VERSION}"
|
||||
fi
|
||||
|
||||
# Download and extract
|
||||
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz"
|
||||
echo "=== Downloading llama-swap v${VERSION} ==="
|
||||
echo "URL: $URL"
|
||||
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
||||
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
|
||||
rm /tmp/llama-swap.tar.gz
|
||||
|
||||
# Validate
|
||||
if [ ! -x "/install/bin/llama-swap" ]; then
|
||||
echo "FATAL: llama-swap binary not found or not executable" >&2
|
||||
ls -la /install/bin/ >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "$VERSION" > /install/llama-swap-version
|
||||
|
||||
echo "=== llama-swap v${VERSION} installed ==="
|
||||
ls -la /install/bin/llama-swap
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
# Install llama.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/llama.cpp
|
||||
cd /src/llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DLLAMA_BUILD_TESTS=OFF
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(llama-cli llama-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building llama.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
echo "=== llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
# Install stable-diffusion.cpp - clone, build, and install binaries and library
|
||||
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/stable-diffusion.cpp
|
||||
cd /src/stable-diffusion.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
git submodule update --init --recursive --depth=1
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DSD_BUILD_EXAMPLES=ON
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
-DSD_CUDA=ON
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
-DSD_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(stable-diffusion sd-cli sd-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in sd-cli sd-server; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== stable-diffusion.cpp build complete ==="
|
||||
ls -la /install/bin/ /install/lib/
|
||||
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
# Install whisper.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/whisper.cpp
|
||||
cd /src/whisper.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/whisper.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(whisper-cli whisper-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building whisper.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== whisper.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 351 KiB |
|
After Width: | Height: | Size: 198 KiB |
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
@@ -114,6 +117,24 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
@@ -126,6 +147,30 @@ metricsMaxInMemory: 1000
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -274,6 +319,33 @@ models:
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models on slower hardware that need longer timeouts
|
||||
# - increase responseHeader to avoid "timeout awaiting response headers" errors
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
# connect: TCP connection timeout in seconds
|
||||
# - default: 30
|
||||
connect: 30
|
||||
|
||||
# responseHeader: time to wait for response headers in seconds
|
||||
# - default: 60
|
||||
# - for slow image generation or large models, consider increasing to 300+ seconds
|
||||
responseHeader: 60
|
||||
|
||||
# tlsHandshake: TLS handshake timeout in seconds
|
||||
# - default: 10
|
||||
tlsHandshake: 10
|
||||
|
||||
# idleConn: idle connection timeout in seconds
|
||||
# - default: 90
|
||||
idleConn: 90
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -383,4 +455,47 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
apiKey: sk-your-openrouter-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
```
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
## Container Security
|
||||
|
||||
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||
|
||||
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||
|
||||
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||
|
||||
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
@@ -37,9 +37,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -95,7 +95,9 @@ func main() {
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
srv.Handler = proxy.New(conf)
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
@@ -110,7 +112,9 @@ func main() {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.Handler = proxy.New(conf)
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
@@ -81,6 +87,7 @@ type GroupConfig struct {
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
@@ -113,7 +120,11 @@ type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -132,6 +143,15 @@ type Config struct {
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -166,21 +186,31 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// default configuration values
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
@@ -188,6 +218,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
@@ -199,55 +239,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -257,23 +297,40 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
@@ -282,29 +339,25 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
@@ -314,13 +367,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
@@ -328,35 +383,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
continue // replaced at runtime
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
// if sendLoadingState is nil, set it to the global config value
|
||||
// see #366
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState // copy it
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
@@ -364,18 +439,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
@@ -383,7 +457,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
@@ -395,10 +469,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -529,20 +649,26 @@ func validateMacro(name string, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -550,7 +676,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -609,3 +735,67 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ models:
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
@@ -162,9 +163,20 @@ groups:
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
@@ -184,6 +196,7 @@ groups:
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -192,6 +205,7 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -200,6 +214,7 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -208,10 +223,12 @@ groups:
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||
@@ -761,3 +762,785 @@ models:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
content: `apiKeys: [""]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, tt.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret-key-123")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY}"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY_1", "key-one")
|
||||
t.Setenv("TEST_API_KEY_2", "key-two")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("missing env var in apiKeys", func(t *testing.T) {
|
||||
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
|
||||
})
|
||||
|
||||
t.Run("env substitution results in empty key", func(t *testing.T) {
|
||||
t.Setenv("TEST_EMPTY_KEY", "")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "empty api key found in apiKeys", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_GlobalTTL(t *testing.T) {
|
||||
t.Run("globalTTL sets default for models", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 300, config.GlobalTTL)
|
||||
assert.Equal(t, 300, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 0
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 600
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 600, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("globalTTL defaults to 0", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.GlobalTTL)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("negative globalTTL rejected", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: -1
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "globalTTL must be >= 0")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_EnvMacros(t *testing.T) {
|
||||
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.TEST_MODEL_PATH}/llama-server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env substitution in multiple fields", func(t *testing.T) {
|
||||
t.Setenv("TEST_HOST", "myserver")
|
||||
t.Setenv("TEST_PORT", "9999")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --host ${env.TEST_HOST}"
|
||||
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
|
||||
checkEndpoint: "http://${env.TEST_HOST}/health"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
|
||||
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
|
||||
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
|
||||
})
|
||||
|
||||
t.Run("env in global macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_BASE_PATH", "/usr/local")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
|
||||
models:
|
||||
test:
|
||||
cmd: "${SERVER_PATH} --port 8080"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in model-level macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_DIR", "/models/llama")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
|
||||
cmd: "server --model ${MODEL_FILE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in metadata", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
api_key: "${env.TEST_API_KEY}"
|
||||
nested:
|
||||
key: "${env.TEST_API_KEY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
|
||||
nested := config.Models["test"].Metadata["nested"].(map[string]any)
|
||||
assert.Equal(t, "secret123", nested["key"])
|
||||
})
|
||||
|
||||
t.Run("env in filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("env in cmdStop", func(t *testing.T) {
|
||||
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --port ${PORT}"
|
||||
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
|
||||
proxy: "http://localhost:${PORT}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
|
||||
})
|
||||
|
||||
t.Run("missing env var returns error", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.UNDEFINED_VAR_12345}/server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in global macro", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in model macro", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in metadata", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
key: "${env.UNDEFINED_META_VAR}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env combined with regular macros", func(t *testing.T) {
|
||||
t.Setenv("TEST_ROOT", "/data")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
MODEL_BASE: "${env.TEST_ROOT}/models"
|
||||
models:
|
||||
test:
|
||||
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("multiple env vars in same string", func(t *testing.T) {
|
||||
t.Setenv("TEST_USER", "admin")
|
||||
t.Setenv("TEST_PASS", "secret")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env value with newline is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_MULTILINE", "line1\nline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_MULTILINE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_MULTILINE")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with carriage return is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_CR", "line1\rline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_CR}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_CR")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_QUOTED", `value with "quotes"`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --arg \"${env.TEST_QUOTED}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Quotes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with quotes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
|
||||
})
|
||||
|
||||
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_BACKSLASH", `path\to\file`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --path \"${env.TEST_BACKSLASH}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Backslashes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with backslashes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in peer apiKey", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.TEST_PEER_API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("missing env var in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.NONEXISTENT_PEER_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
|
||||
})
|
||||
|
||||
t.Run("static apiKey unchanged", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-static-key
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_KEY_1", "key-one")
|
||||
t.Setenv("TEST_PEER_KEY_2", "key-two")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: https://peer1.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_1}"
|
||||
models:
|
||||
- model-a
|
||||
peer2:
|
||||
proxy: https://peer2.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_2}"
|
||||
models:
|
||||
- model-b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
|
||||
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
API_KEY: sk-from-global-macro
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
STRIP_LIST: "temperature, top_p"
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${STRIP_LIST}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
MAX_TOKENS: 4096
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
max_tokens: "${MAX_TOKENS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_RETENTION_POLICY", "deny")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
data_collection: "${env.TEST_RETENTION_POLICY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${UNDEFINED_MACRO}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
value: "${UNDEFINED_MACRO}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("env macros in comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
# apiKeys:
|
||||
# - "${env.COMMENTED_OUT_KEY_1}"
|
||||
# - "${env.COMMENTED_OUT_KEY_2}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
// These env vars are NOT set, but should not cause an error
|
||||
// because they only appear in comment lines
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in comments ignored while active ones resolve", func(t *testing.T) {
|
||||
t.Setenv("TEST_ACTIVE_KEY", "active-key-value")
|
||||
|
||||
content := `
|
||||
# apiKeys: ["${env.COMMENTED_OUT_KEY}"]
|
||||
apiKeys: ["${env.TEST_ACTIVE_KEY}"]
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"active-key-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in indented comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: |
|
||||
server
|
||||
--port 8080
|
||||
proxy: "http://localhost:8080"
|
||||
# metadata:
|
||||
# api_key: "${env.SOME_UNSET_KEY}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("env macros in inline comments are ignored", func(t *testing.T) {
|
||||
t.Setenv("TEST_INLINE_KEY", "real-value")
|
||||
|
||||
content := `
|
||||
apiKeys: ["${env.TEST_INLINE_KEY}"] # TODO: add ${env.FUTURE_KEY} later
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"real-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsParsing(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
timeouts:
|
||||
connect: 45
|
||||
responseHeader: 120
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
assert.Equal(t, 45, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 120, modelConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsDefaults(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
// Default values should be set during unmarshaling
|
||||
assert.Equal(t, 30, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
|
||||
assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake)
|
||||
assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, modelConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsZeroAllowed(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
timeouts:
|
||||
connect: 0
|
||||
responseHeader: 0
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
// Explicit 0 should be preserved (disables timeout)
|
||||
assert.Equal(t, 0, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_PeerTimeoutsParsing(t *testing.T) {
|
||||
configYaml := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: http://example.com
|
||||
models: [model1]
|
||||
timeouts:
|
||||
connect: 45
|
||||
responseHeader: 120
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerConfig, found := config.Peers["peer1"]
|
||||
require.True(t, found, "peer1 should exist in config")
|
||||
|
||||
assert.Equal(t, 45, peerConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 120, peerConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_PeerTimeoutsDefaults(t *testing.T) {
|
||||
configYaml := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: http://example.com
|
||||
models: [model1]
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerConfig, found := config.Peers["peer1"]
|
||||
require.True(t, found, "peer1 should exist in config")
|
||||
|
||||
// Default values should be set during unmarshaling
|
||||
assert.Equal(t, 30, peerConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 60, peerConfig.Timeouts.ResponseHeader)
|
||||
assert.Equal(t, 10, peerConfig.Timeouts.TLSHandshake)
|
||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ models:
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
@@ -154,9 +155,20 @@ groups:
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
@@ -170,6 +182,7 @@ groups:
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -179,6 +192,7 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -188,6 +202,7 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -197,10 +212,12 @@ groups:
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
|
||||
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||
// Useful with aliases: a single loaded model can behave differently depending on
|
||||
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||
// Protected params (like "model") cannot be set.
|
||||
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||
// with protected params removed and keys sorted for consistent iteration order.
|
||||
// Returns nil if the ID has no entry or all its params are protected.
|
||||
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||
if len(f.SetParamsByID) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
params, found := f.SetParamsByID[requestedModelID]
|
||||
if !found || len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]any, len(params))
|
||||
keys := make([]string, 0, len(params))
|
||||
for key, value := range params {
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, keys
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParamsByID map[string]map[string]any
|
||||
requestedModelID string
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty SetParamsByID returns nil",
|
||||
setParamsByID: nil,
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map returns nil",
|
||||
setParamsByID: map[string]map[string]any{},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "non-matching model ID returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model2": {"temperature": 0.9},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "matching model ID returns correct params",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||
"model2": {"temperature": 0.5},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected param model is filtered out",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "keys are sorted",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||
},
|
||||
{
|
||||
name: "alias style key lookup",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1:high": {"reasoning_effort": "high"},
|
||||
"model1:low": {"reasoning_effort": "low"},
|
||||
},
|
||||
requestedModelID: "model1:high",
|
||||
wantParams: map[string]any{
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
wantKeys: []string{"reasoning_effort"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams)
|
||||
assert.Nil(t, gotKeys)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||
assert.Equal(t, tt.wantParams, gotParams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -104,6 +104,62 @@ models:
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test macro substitution in name and description fields
|
||||
func TestConfig_MacroInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"VARIANT": "Q4_K_M"
|
||||
"FAMILY": "llama"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "${FAMILY} ${VARIANT}"
|
||||
description: "A ${FAMILY} model in ${VARIANT} format"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
|
||||
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
|
||||
}
|
||||
|
||||
// Test MODEL_ID macro in name and description fields
|
||||
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
llama-3b:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model: ${MODEL_ID}"
|
||||
description: "Running ${MODEL_ID}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
|
||||
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
|
||||
}
|
||||
|
||||
// Test unknown macro in name or description returns an error
|
||||
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model ${UNDEFINED}"
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
|
||||
@@ -3,10 +3,23 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
Connect int `yaml:"connect"`
|
||||
KeepAlive int `yaml:"keepalive"`
|
||||
ResponseHeader int `yaml:"responseHeader"`
|
||||
TLSHandshake int `yaml:"tlsHandshake"`
|
||||
ExpectContinue int `yaml:"expectContinue"`
|
||||
IdleConn int `yaml:"idleConn"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
@@ -38,6 +51,9 @@ type ModelConfig struct {
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
@@ -49,12 +65,22 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
|
||||
// matches http.DefaultTransport
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
@@ -74,16 +100,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
@@ -104,25 +129,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
|
||||
@@ -72,3 +72,101 @@ models:
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Keys (other than the model's own ID) should be registered as aliases
|
||||
realName, found := cfg.RealModelName("model1:high")
|
||||
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
realName, found = cfg.RealModelName("model1:low")
|
||||
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
// Auto-aliases should also appear in modelConfig.Aliases
|
||||
aliases := cfg.Models["model1"].Aliases
|
||||
assert.Contains(t, aliases, "model1:high")
|
||||
assert.Contains(t, aliases, "model1:low")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
model2:
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: low
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "duplicate alias")
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
|
||||
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
|
||||
// to match the pre PR #619 functionality
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
|
||||
@@ -71,11 +71,15 @@ func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
|
||||
@@ -1,16 +1,95 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
|
||||
func newCircularBuffer(capacity int) *circularBuffer {
|
||||
return &circularBuffer{
|
||||
data: make([]byte, capacity),
|
||||
head: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||
// Data is copied into the internal buffer (not stored by reference).
|
||||
func (cb *circularBuffer) Write(p []byte) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cap := len(cb.data)
|
||||
|
||||
// If input is larger than capacity, only keep the last cap bytes
|
||||
if len(p) >= cap {
|
||||
copy(cb.data, p[len(p)-cap:])
|
||||
cb.head = 0
|
||||
cb.size = cap
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate how much space is available from head to end of buffer
|
||||
firstPart := cap - cb.head
|
||||
if firstPart >= len(p) {
|
||||
// All data fits without wrapping
|
||||
copy(cb.data[cb.head:], p)
|
||||
cb.head = (cb.head + len(p)) % cap
|
||||
} else {
|
||||
// Data wraps around
|
||||
copy(cb.data[cb.head:], p[:firstPart])
|
||||
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||
cb.head = len(p) - firstPart
|
||||
}
|
||||
|
||||
// Update size
|
||||
cb.size += len(p)
|
||||
if cb.size > cap {
|
||||
cb.size = cap
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||
// Returns a new slice (copy), not a view into internal buffer.
|
||||
func (cb *circularBuffer) GetHistory() []byte {
|
||||
if cb.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, cb.size)
|
||||
cap := len(cb.data)
|
||||
|
||||
// Calculate start position (oldest data)
|
||||
start := (cb.head - cb.size + cap) % cap
|
||||
|
||||
if start+cb.size <= cap {
|
||||
// Data is contiguous, single copy
|
||||
copy(result, cb.data[start:start+cb.size])
|
||||
} else {
|
||||
// Data wraps around, two copies
|
||||
firstPart := cap - start
|
||||
copy(result[:firstPart], cb.data[start:])
|
||||
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
@@ -18,12 +97,14 @@ const (
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
|
||||
LogBufferSize = 100 * 1024
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
buffer *circularBuffer
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
@@ -32,6 +113,9 @@ type LogMonitor struct {
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
|
||||
// timestamps
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -40,11 +124,12 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: nil, // lazy initialized on first Write
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
timeFormat: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
if w.buffer == nil {
|
||||
w.buffer = newCircularBuffer(LogBufferSize)
|
||||
}
|
||||
w.buffer.Write(p)
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
// Make a copy for broadcast to preserve immutability
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
@@ -72,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return w.buffer.GetHistory()
|
||||
}
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
})
|
||||
return history
|
||||
// Clear releases the buffer memory, making it eligible for GC.
|
||||
// The buffer will be lazily re-allocated on the next Write.
|
||||
func (w *LogMonitor) Clear() {
|
||||
w.bufferMu.Lock()
|
||||
w.buffer = nil
|
||||
w.bufferMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
@@ -106,12 +196,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.timeFormat = timeFormat
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||
timestamp := ""
|
||||
if w.timeFormat != "" {
|
||||
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
|
||||
@@ -3,8 +3,10 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
@@ -84,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Enable timestamps
|
||||
lm.timeFormat = time.RFC3339
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
lm.Info("Hello, World!")
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
timestamp := ""
|
||||
fields := strings.Fields(string(history))
|
||||
if len(fields) > 0 {
|
||||
timestamp = fields[0]
|
||||
} else {
|
||||
t.Fatalf("Cannot extract string from history")
|
||||
}
|
||||
|
||||
_, err := time.Parse(time.RFC3339, timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||
// Create a small buffer to test wrap-around
|
||||
cb := newCircularBuffer(10)
|
||||
|
||||
// Write "hello" (5 bytes)
|
||||
cb.Write([]byte("hello"))
|
||||
if got := string(cb.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Write "world" (5 bytes) - buffer now full
|
||||
cb.Write([]byte("world"))
|
||||
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||
t.Errorf("Expected 'helloworld', got %q", got)
|
||||
}
|
||||
|
||||
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||
cb.Write([]byte("12345"))
|
||||
if got := string(cb.GetHistory()); got != "world12345" {
|
||||
t.Errorf("Expected 'world12345', got %q", got)
|
||||
}
|
||||
|
||||
// Write data larger than buffer capacity
|
||||
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||
// Test empty buffer
|
||||
cb := newCircularBuffer(10)
|
||||
if got := cb.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
|
||||
// Test exact capacity
|
||||
cb.Write([]byte("1234567890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
|
||||
// Test write exactly at capacity boundary
|
||||
cb = newCircularBuffer(10)
|
||||
cb.Write([]byte("12345"))
|
||||
cb.Write([]byte("67890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Buffer should be nil before any writes
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil before first write")
|
||||
}
|
||||
|
||||
// GetHistory should return nil when buffer is nil
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history before first write, got %q", got)
|
||||
}
|
||||
|
||||
// Write should lazily initialize the buffer
|
||||
lm.Write([]byte("test"))
|
||||
|
||||
if lm.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized after write")
|
||||
}
|
||||
|
||||
if got := string(lm.GetHistory()); got != "test" {
|
||||
t.Errorf("Expected 'test', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_Clear(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write some data
|
||||
lm.Write([]byte("hello"))
|
||||
if got := string(lm.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Clear should release the buffer
|
||||
lm.Clear()
|
||||
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil after Clear")
|
||||
}
|
||||
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write, clear, then write again
|
||||
lm.Write([]byte("first"))
|
||||
lm.Clear()
|
||||
lm.Write([]byte("second"))
|
||||
|
||||
if got := string(lm.GetHistory()); got != "second" {
|
||||
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||
// Test data of varying sizes
|
||||
smallMsg := []byte("small message\n")
|
||||
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||
|
||||
b.Run("SmallWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(smallMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MediumWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LargeWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(largeMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithSubscribers", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Add some subscribers
|
||||
for i := 0; i < 5; i++ {
|
||||
lm.OnLogData(func(data []byte) {})
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetHistory", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Pre-populate with data
|
||||
for i := 0; i < 1000; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.GetHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
Benchmark Results - MBP M1 Pro
|
||||
|
||||
Before (ring.Ring):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||
|
||||
After (circularBuffer 10KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||
|
||||
After (circularBuffer 100KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|-----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||
|
||||
Summary:
|
||||
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||
- Allocations: reduced from 2 to 1 across all operations
|
||||
- Small/medium writes: ~1.1-1.6x faster
|
||||
*/
|
||||
|
||||
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -26,6 +28,28 @@ type TokenMetrics struct {
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// Size returns the approximate memory usage of this capture in bytes
|
||||
func (c *ReqRespCapture) Size() int {
|
||||
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
||||
for k, v := range c.ReqHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
for k, v := range c.RespHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
@@ -44,19 +68,32 @@ type metricsMonitor struct {
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total size in bytes
|
||||
maxCaptureSize int // max bytes for captures
|
||||
}
|
||||
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||
mp := &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int]ReqRespCapture),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
@@ -67,6 +104,49 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
}
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
captureSize := capture.Size()
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
mp.captureSize -= evicted.Size()
|
||||
delete(mp.captures, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = capture
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
}
|
||||
|
||||
// getCaptureByID returns a capture by its ID, or nil if not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
if capture, exists := mp.captures[id]; exists {
|
||||
return &capture
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
@@ -95,7 +175,35 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,30 +216,94 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
// Only set HasCapture if the capture will actually be stored (not too large)
|
||||
if capture.Size() <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,19 +346,27 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
return parseMetrics(modelID, start, gjson.ParseBytes(data))
|
||||
parsed := gjson.ParseBytes(data)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// v1/responses format nests usage under response.usage
|
||||
if !usage.Exists() {
|
||||
usage = parsed.Get("response.usage")
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
return parseMetrics(modelID, start, usage, timings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
|
||||
usage := jsonData.Get("usage")
|
||||
timings := jsonData.Get("timings")
|
||||
if !usage.Exists() && !timings.Exists() {
|
||||
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
|
||||
}
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
// default values
|
||||
cachedTokens := -1 // unknown or missing data
|
||||
outputTokens := 0
|
||||
@@ -195,22 +375,41 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
||||
// timings data
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
|
||||
if usage.Exists() {
|
||||
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||
// v1/chat/completions
|
||||
inputTokens = int(pt.Int())
|
||||
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||
// v1/messages
|
||||
inputTokens = int(it.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||
// v1/chat/completions
|
||||
outputTokens = int(ct.Int())
|
||||
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||
outputTokens = int(ot.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||
cachedTokens = int(ct.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings.Exists() {
|
||||
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||
inputTokens = int(timings.Get("prompt_n").Int())
|
||||
outputTokens = int(timings.Get("predicted_n").Int())
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
|
||||
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = int(cachedValue.Int())
|
||||
}
|
||||
}
|
||||
@@ -227,6 +426,25 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
@@ -265,3 +483,43 @@ func (w *responseBodyCopier) Header() http.Header {
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,11 +14,12 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -34,7 +38,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||
@@ -48,7 +52,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 3)
|
||||
mm := newMetricsMonitor(testLogger, 3, 0)
|
||||
|
||||
// Add 5 metrics
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -68,7 +72,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||
cancel := event.On(func(e TokenMetricsEvent) {
|
||||
@@ -98,14 +102,14 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
metrics := mm.getMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||
|
||||
@@ -125,7 +129,7 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
@@ -137,7 +141,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model1",
|
||||
InputTokens: 100,
|
||||
@@ -165,7 +169,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"usage": {
|
||||
@@ -196,7 +200,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("successful request with timings data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -236,7 +240,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||
@@ -272,7 +276,7 @@ data: [DONE]
|
||||
})
|
||||
|
||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -291,8 +295,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -307,11 +311,14 @@ data: [DONE]
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -328,11 +335,14 @@ data: [DONE]
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
expectedErr := assert.AnError
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -350,8 +360,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"result": "ok"}`
|
||||
|
||||
@@ -367,10 +377,82 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Infill response is an array with timings in the last element
|
||||
responseBody := `[
|
||||
{"content": "first chunk"},
|
||||
{"content": "second chunk"},
|
||||
{"content": "final", "timings": {
|
||||
"prompt_n": 150,
|
||||
"predicted_n": 75,
|
||||
"prompt_per_second": 200.5,
|
||||
"predicted_per_second": 35.5,
|
||||
"prompt_ms": 600.0,
|
||||
"predicted_ms": 1800.0,
|
||||
"cache_n": 30
|
||||
}}
|
||||
]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 150, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 30, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||
})
|
||||
|
||||
t.Run("infill request with empty array records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `[]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -425,7 +507,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
@@ -452,7 +534,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
@@ -489,8 +571,29 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
t.Run("keeps wall clock duration when timings underreport request time", func(t *testing.T) {
|
||||
start := time.Now().Add(-5 * time.Second)
|
||||
usage := gjson.Parse(`{"prompt_tokens": 5, "completion_tokens": 1}`)
|
||||
timings := gjson.Parse(`{
|
||||
"prompt_n": 5,
|
||||
"predicted_n": 1,
|
||||
"prompt_per_second": 10.0,
|
||||
"predicted_per_second": 2.0,
|
||||
"prompt_ms": 5.0,
|
||||
"predicted_ms": 15.0
|
||||
}`)
|
||||
|
||||
metrics, err := parseMetrics("test-model", start, usage, timings)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, metrics.InputTokens)
|
||||
assert.Equal(t, 1, metrics.OutputTokens)
|
||||
assert.Equal(t, 10.0, metrics.PromptPerSecond)
|
||||
assert.Equal(t, 2.0, metrics.TokensPerSecond)
|
||||
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
||||
})
|
||||
|
||||
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Timings should take precedence over usage
|
||||
responseBody := `{
|
||||
@@ -530,7 +633,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -565,7 +668,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Metrics should be found in the last data line before [DONE]
|
||||
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||
@@ -598,8 +701,8 @@ data: [DONE]
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `data: not json
|
||||
|
||||
@@ -619,14 +722,46 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// v1/responses SSE format: usage is nested under response.usage
|
||||
responseBody := "event: response.completed\n" +
|
||||
`data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","created_at":1773416985,"status":"completed","model":"test-model","output":[],"usage":{"input_tokens":17,"output_tokens":23,"total_tokens":40}}}` +
|
||||
"\n\n"
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 17, metrics[0].InputTokens)
|
||||
assert.Equal(t, 23, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := ``
|
||||
|
||||
@@ -642,17 +777,19 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
// Empty body should not trigger WrapHandler processing
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -673,7 +810,7 @@ func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
|
||||
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -691,3 +828,352 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
t.Run("gzip encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
// Compress with gzip
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
gzWriter.Write([]byte(responseBody))
|
||||
gzWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("deflate encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
|
||||
|
||||
// Compress with deflate
|
||||
var buf bytes.Buffer
|
||||
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
|
||||
flateWriter.Write([]byte(responseBody))
|
||||
flateWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "deflate")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Invalid compressed data
|
||||
invalidData := []byte("this is not gzip data")
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(invalidData)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Should not return error, just log warning
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "unknown-encoding")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReqRespCapture_Size(t *testing.T) {
|
||||
t.Run("calculates size correctly", func(t *testing.T) {
|
||||
capture := ReqRespCapture{
|
||||
ID: 1,
|
||||
ReqPath: "/v1/chat/completions", // 20 bytes
|
||||
ReqHeaders: map[string]string{
|
||||
"Content-Type": "application/json", // 12 + 16 = 28
|
||||
},
|
||||
ReqBody: []byte("request body"), // 12 bytes
|
||||
RespHeaders: map[string]string{
|
||||
"X-Test": "value", // 6 + 5 = 11
|
||||
},
|
||||
RespBody: []byte("response body"), // 13 bytes
|
||||
}
|
||||
|
||||
// Expected: 20 + 12 + 13 + 28 + 11 = 84
|
||||
assert.Equal(t, 84, capture.Size())
|
||||
})
|
||||
|
||||
t.Run("handles empty capture", func(t *testing.T) {
|
||||
capture := ReqRespCapture{}
|
||||
assert.Equal(t, 0, capture.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
t.Run("does nothing when captures disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
// Should not store capture
|
||||
assert.Nil(t, mm.getCaptureByID(0))
|
||||
})
|
||||
|
||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test request"),
|
||||
RespBody: []byte("test response"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(0)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 0, retrieved.ID)
|
||||
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), retrieved.RespBody)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100 // Set small limit for test
|
||||
|
||||
// Add captures that will exceed the limit
|
||||
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
|
||||
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
|
||||
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
|
||||
|
||||
mm.addCapture(capture1)
|
||||
mm.addCapture(capture2)
|
||||
// Adding capture3 should evict capture1
|
||||
mm.addCapture(capture3)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
||||
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
||||
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
||||
})
|
||||
|
||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100
|
||||
|
||||
// Add a capture larger than max
|
||||
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
|
||||
mm.addCapture(largeCapture)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(999))
|
||||
})
|
||||
|
||||
t.Run("returns capture by ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 42,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(42)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 42, retrieved.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Run("redacts sensitive headers", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer secret-token",
|
||||
"Proxy-Authorization": "Basic creds",
|
||||
"Cookie": "session=abc123",
|
||||
"Set-Cookie": "session=xyz789",
|
||||
"X-Api-Key": "sk-12345",
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom": "safe-value",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Set-Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["X-Api-Key"])
|
||||
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||
assert.Equal(t, "safe-value", headers["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("handles mixed case header names", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"authorization": "Bearer token",
|
||||
"COOKIE": "session=abc",
|
||||
"x-api-key": "key123",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["COOKIE"])
|
||||
assert.Equal(t, "[REDACTED]", headers["x-api-key"])
|
||||
})
|
||||
|
||||
t.Run("handles empty headers", func(t *testing.T) {
|
||||
headers := map[string]string{}
|
||||
redactHeaders(headers)
|
||||
assert.Empty(t, headers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
t.Run("captures request and response when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
requestBody := `{"model": "test", "prompt": "hello"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Custom", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check metric was recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
metricID := metrics[0].ID
|
||||
|
||||
// Check capture was stored with same ID
|
||||
capture := mm.getCaptureByID(metricID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, metricID, capture.ID)
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
assert.Equal(t, "/test", capture.ReqPath)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("does not capture when disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
requestBody := `{"model": "test"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Metrics should still be recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
|
||||
// But no capture
|
||||
capture := mm.getCaptureByID(metrics[0].ID)
|
||||
assert.Nil(t, capture)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -96,6 +96,24 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
|
||||
// Create custom transport with configured timeouts
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.Transport = transport
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
@@ -256,6 +274,7 @@ func (p *Process) start() error {
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
@@ -413,6 +432,9 @@ func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
@@ -506,7 +528,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
@@ -625,6 +650,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
@@ -641,6 +667,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
@@ -733,6 +764,14 @@ func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Recover from panics caused by client disconnection
|
||||
// Note: recover() only works within the same goroutine, so we need it here
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
@@ -851,7 +890,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
// Add Flush method
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -117,12 +118,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -159,12 +160,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
@@ -395,6 +396,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
@@ -565,3 +570,39 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Cmd: "echo test",
|
||||
Proxy: "http://localhost:8080",
|
||||
CheckEndpoint: "/health",
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 120,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
}
|
||||
|
||||
debugLogger := NewLogMonitorWriter(io.Discard)
|
||||
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||
|
||||
// Verify the process was created successfully
|
||||
assert.NotNil(t, process)
|
||||
assert.Equal(t, "test-model", process.ID)
|
||||
assert.NotNil(t, process.reverseProxy)
|
||||
assert.NotNil(t, process.reverseProxy.Transport)
|
||||
|
||||
// Verify it's using http.Transport (not some other type)
|
||||
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||
assert.True(t, ok, "Transport should be *http.Transport")
|
||||
assert.NotNil(t, transport)
|
||||
|
||||
// Verify the timeouts are correctly applied
|
||||
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
@@ -88,6 +89,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
@@ -27,6 +28,40 @@ const (
|
||||
|
||||
type proxyCtxKey string
|
||||
|
||||
type InflightCounter struct {
|
||||
mu sync.Mutex
|
||||
total int
|
||||
}
|
||||
|
||||
func newInflightCounter() *InflightCounter {
|
||||
return &InflightCounter{}
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Current() int {
|
||||
ic.mu.Lock()
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Increment() int {
|
||||
ic.mu.Lock()
|
||||
ic.total++
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Decrement() int {
|
||||
ic.mu.Lock()
|
||||
if ic.total > 0 {
|
||||
ic.total--
|
||||
}
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
@@ -42,22 +77,52 @@ type ProxyManager struct {
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
inFlightCounter *InflightCounter
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
|
||||
// version info
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
|
||||
// peer proxy see: #296, #433
|
||||
peerProxy *PeerProxy
|
||||
}
|
||||
|
||||
func New(config config.Config) *ProxyManager {
|
||||
func New(proxyConfig config.Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
|
||||
if config.LogRequests {
|
||||
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
|
||||
switch proxyConfig.LogToStdout {
|
||||
case config.LogToStdoutNone:
|
||||
muxLogger = NewLogMonitorWriter(io.Discard)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
case config.LogToStdoutBoth:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
case config.LogToStdoutUpstream:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
default:
|
||||
// same as config.LogToStdoutProxy
|
||||
// helpful because some old tests create a config.Config directly and it
|
||||
// may not have LogToStdout set explicitly
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
}
|
||||
|
||||
if proxyConfig.LogRequests {
|
||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
upstreamLogger.SetLogLevel(LevelDebug)
|
||||
@@ -75,60 +140,105 @@ func New(config config.Config) *ProxyManager {
|
||||
upstreamLogger.SetLogLevel(LevelInfo)
|
||||
}
|
||||
|
||||
// see: https://go.dev/src/time/format.go
|
||||
timeFormats := map[string]string{
|
||||
"ansic": time.ANSIC,
|
||||
"unixdate": time.UnixDate,
|
||||
"rubydate": time.RubyDate,
|
||||
"rfc822": time.RFC822,
|
||||
"rfc822z": time.RFC822Z,
|
||||
"rfc850": time.RFC850,
|
||||
"rfc1123": time.RFC1123,
|
||||
"rfc1123z": time.RFC1123Z,
|
||||
"rfc3339": time.RFC3339,
|
||||
"rfc3339nano": time.RFC3339Nano,
|
||||
"kitchen": time.Kitchen,
|
||||
"stamp": time.Stamp,
|
||||
"stampmilli": time.StampMilli,
|
||||
"stampmicro": time.StampMicro,
|
||||
"stampnano": time.StampNano,
|
||||
}
|
||||
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
||||
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
var maxMetrics int
|
||||
if config.MetricsMaxInMemory <= 0 {
|
||||
if proxyConfig.MetricsMaxInMemory <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
} else {
|
||||
maxMetrics = config.MetricsMaxInMemory
|
||||
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||
peerProxy = nil
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
config: proxyConfig,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
muxLogger: muxLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
inFlightCounter: newInflightCounter(),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
for groupID := range proxyConfig.Groups {
|
||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
|
||||
// run any startup hooks
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||
|
||||
if !ok {
|
||||
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
@@ -203,35 +313,50 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Protected routes use pm.apiKeyAuth() middleware
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic count_tokens API (Also added in the above PR)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
// sd.cpp /sdapi/v1 endpoints
|
||||
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -243,9 +368,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
@@ -267,25 +392,35 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
if err != nil {
|
||||
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||
} else {
|
||||
// Serve files with compression support under /ui/*
|
||||
// This handler checks for pre-compressed .br and .gz files
|
||||
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
||||
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
||||
// Default to index.html for directory-like paths
|
||||
if filepath == "" {
|
||||
filepath = "index.html"
|
||||
}
|
||||
|
||||
// serve files that exist under /ui/*
|
||||
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
||||
})
|
||||
|
||||
// server SPA for UI under /ui/*
|
||||
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
||||
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
file, err := reactFS.Open("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
// Check if this looks like a file request (has extension)
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
||||
// This was likely a file request that wasn't found
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||
|
||||
// Serve index.html for SPA routing
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -297,6 +432,14 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
gin.DisableConsoleColor()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
||||
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
@@ -343,16 +486,10 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
@@ -364,20 +501,16 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
return processGroup, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||
record := gin.H{
|
||||
"id": id,
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
@@ -396,8 +529,41 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
data = append(data, record)
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
data = append(data, newRecord(id, modelConfig))
|
||||
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias, modelConfig))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
// add peer models
|
||||
for _, modelID := range peer.Models {
|
||||
// Skip unlisted models if not showing them
|
||||
record := newRecord(modelID, config.ModelConfig{
|
||||
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||
Metadata: map[string]any{
|
||||
"peerID": peerID,
|
||||
},
|
||||
})
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
@@ -419,62 +585,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
// split the upstream path by / and search for the model name
|
||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||
if len(parts) == 0 {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
modelFound := false
|
||||
// findModelInPath searches for a valid model name in a path with slashes.
|
||||
// It iteratively builds up path segments until it finds a matching model.
|
||||
// Returns: (searchModelName, realModelName, remainingPath, found)
|
||||
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
||||
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
searchModelName := ""
|
||||
var modelName, remainingPath string
|
||||
|
||||
for i, part := range parts {
|
||||
if parts[i] == "" {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if searchModelName == "" {
|
||||
searchModelName = part
|
||||
} else {
|
||||
searchModelName = searchModelName + "/" + parts[i]
|
||||
searchModelName = searchModelName + "/" + part
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
break
|
||||
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||
// This ensures relative URLs in upstream responses resolve correctly and
|
||||
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||
// HTTP method (301 would downgrade to GET).
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -486,21 +651,21 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
@@ -513,41 +678,101 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||
for _, key := range setParamsByIDKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -560,19 +785,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -592,9 +817,29 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
// Look for a matching local model first, then check peers
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var useModelName string
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -610,8 +855,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
@@ -681,9 +924,46 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||
requestedModel := c.Query("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
||||
return
|
||||
}
|
||||
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var modelID string
|
||||
|
||||
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
modelID = realModelID
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
modelID = requestedModel
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -698,6 +978,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||
// Returns a pass-through handler if no API keys are configured.
|
||||
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
xApiKey := c.GetHeader("x-api-key")
|
||||
|
||||
var bearerKey string
|
||||
var basicKey string
|
||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
// Basic Auth: base64(username:password), password is the API key
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
basicKey = parts[1] // password is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use first key found: Basic, then Bearer, then x-api-key
|
||||
var providedKey string
|
||||
if basicKey != "" {
|
||||
providedKey = basicKey
|
||||
} else if bearerKey != "" {
|
||||
providedKey = bearerKey
|
||||
} else {
|
||||
providedKey = xApiKey
|
||||
}
|
||||
|
||||
// Validate key
|
||||
valid := false
|
||||
for _, key := range pm.config.RequiredAPIKeys {
|
||||
if providedKey == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Strip auth headers to prevent leakage to upstream
|
||||
c.Request.Header.Del("Authorization")
|
||||
c.Request.Header.Del("x-api-key")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
@@ -711,8 +1052,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"cmd": process.config.Cmd,
|
||||
"proxy": process.config.Proxy,
|
||||
"ttl": process.config.UnloadAfter,
|
||||
"name": process.config.Name,
|
||||
"description": process.config.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -734,3 +1080,11 @@ func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
pm.buildDate = buildDate
|
||||
pm.commit = commit
|
||||
pm.version = version
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,21 +14,26 @@ import (
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,9 +84,22 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
Aliases: pm.config.Models[modelID].Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -90,6 +109,7 @@ const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
@@ -149,6 +169,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
@@ -176,11 +208,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -227,3 +267,28 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"version": pm.version,
|
||||
"commit": pm.commit,
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, capture)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
var logger *LogMonitor
|
||||
|
||||
if logMonitorId == "" {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
logger = pm.muxLogger
|
||||
} else if logMonitorId == "proxy" {
|
||||
logger = pm.proxyLogger
|
||||
} else if logMonitorId == "upstream" {
|
||||
logger = pm.upstreamLogger
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||
}
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := config.Config{
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
Peers: map[string]config.PeerConfig{
|
||||
"peer1": {
|
||||
Proxy: "http://peer1:8080",
|
||||
Models: []string{"peer-model-a", "peer-model-b"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(cfg)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
// Check the number of models returned (3 local + 2 peer models)
|
||||
assert.Len(t, response.Data, 5)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"peer-model-a": {},
|
||||
"peer-model-b": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
|
||||
// Peer models should have meta.llamaswap.peerID
|
||||
meta, exists := model["meta"]
|
||||
assert.True(t, exists, "peer model should have meta field")
|
||||
metaMap, ok := meta.(map[string]interface{})
|
||||
assert.True(t, ok, "meta should be a map")
|
||||
llamaswap, exists := metaMap["llamaswap"]
|
||||
assert.True(t, exists, "meta should have llamaswap field")
|
||||
llamaswapMap, ok := llamaswap.(map[string]interface{})
|
||||
assert.True(t, ok, "llamaswap should be a map")
|
||||
peerID, exists := llamaswapMap["peerID"]
|
||||
assert.True(t, exists, "llamaswap should have peerID field")
|
||||
assert.Equal(t, "peer1", peerID)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
@@ -437,7 +455,75 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
||||
// Configure alias
|
||||
config := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
IncludeAliasesInList: true,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": func() config.ModelConfig {
|
||||
mc := getTestSimpleResponderConfig("model1")
|
||||
mc.Name = "Model 1"
|
||||
mc.Aliases = []string{"alias1"}
|
||||
return mc
|
||||
}(),
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Request models list
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// We expect both base id and alias
|
||||
var model1Data, alias1Data map[string]any
|
||||
for _, model := range response.Data {
|
||||
if model["id"] == "model1" {
|
||||
model1Data = model
|
||||
} else if model["id"] == "alias1" {
|
||||
alias1Data = model
|
||||
}
|
||||
}
|
||||
|
||||
// Verify model1 has name
|
||||
assert.NotNil(t, model1Data)
|
||||
_, exists := model1Data["name"]
|
||||
if !assert.True(t, exists, "model1 should have name key") {
|
||||
t.FailNow()
|
||||
}
|
||||
name1, ok := model1Data["name"].(string)
|
||||
assert.True(t, ok, "name1 should be a string")
|
||||
|
||||
// Verify alias1 has name
|
||||
assert.NotNil(t, alias1Data)
|
||||
_, exists = alias1Data["name"]
|
||||
if !assert.True(t, exists, "alias1 should have name key") {
|
||||
t.FailNow()
|
||||
}
|
||||
name2, ok := alias1Data["name"].(string)
|
||||
assert.True(t, ok, "name2 should be a string")
|
||||
|
||||
// Name keys should match
|
||||
assert.Equal(t, name1, name2)
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
@@ -586,8 +672,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
@@ -635,6 +726,11 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
|
||||
// Verify extended fields are present
|
||||
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
||||
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
||||
assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -754,6 +850,43 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioVoicesGETHandler(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful GET with model query param", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "voice1")
|
||||
})
|
||||
|
||||
t.Run("missing model query param returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -880,7 +1013,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
Filters: config.Filters{
|
||||
StripParams: "temperature, model, stream",
|
||||
},
|
||||
}
|
||||
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
@@ -911,6 +1046,61 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_FiltersSetParamsByID(t *testing.T) {
|
||||
// no explicit aliases — setParamsByID keys are auto-registered as aliases
|
||||
configStr := strings.Replace(`
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: 'SRPATH --port ${PORT} --silent --respond model1'
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
filters:
|
||||
setParams:
|
||||
reasoning_effort: medium
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`, "SRPATH", simpleResponderPath, -1)
|
||||
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
if !assert.NoError(t, err, "invalid test configuration") {
|
||||
return
|
||||
}
|
||||
|
||||
proxy := New(cfg)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []struct {
|
||||
requestedModel string
|
||||
wantEffort string
|
||||
}{
|
||||
// setParams applies, no setParamsByID match
|
||||
{requestedModel: "model1", wantEffort: "medium"},
|
||||
// setParamsByID overrides setParams
|
||||
{requestedModel: "model1:high", wantEffort: "high"},
|
||||
{requestedModel: "model1:low", wantEffort: "low"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":%q}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
requestBody, _ := response["request_body"].(string)
|
||||
gotEffort := gjson.Get(requestBody, "reasoning_effort").String()
|
||||
assert.Equal(t, tt.wantEffort, gotEffort, "reasoning_effort mismatch for model %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -1014,7 +1204,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"author/model": getTestSimpleResponderConfig("author/model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
@@ -1027,6 +1218,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
"/logs/stream",
|
||||
"/logs/stream/proxy",
|
||||
"/logs/stream/upstream",
|
||||
"/logs/stream/author/model",
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
@@ -1083,3 +1275,466 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
|
||||
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||
assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||
}
|
||||
|
||||
func TestProxyManager_ApiGetVersion(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
// Version test map
|
||||
versionTest := map[string]string{
|
||||
"build_date": "1970-01-01T00:00:00Z",
|
||||
"commit": "cc915ddb6f04a42d9cd1f524e1d46ec6ed069fdc",
|
||||
"version": "v001",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proxy.SetVersion(versionTest["build_date"], versionTest["commit"], versionTest["version"])
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/version", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Ensure json response
|
||||
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
|
||||
|
||||
// Check for attributes
|
||||
response := map[string]string{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
for key, value := range versionTest {
|
||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("both headers with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
req.Header.Set("Authorization", "Bearer valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "invalid-key")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("missing key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
// Basic Auth: base64("anyuser:valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||
// Config without RequiredAPIKeys - auth should be disabled
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
|
||||
// in proxyInferenceHandler for issue #433
|
||||
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
|
||||
t.Run("requests to peer models are proxied", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config with peers but no local model for "peer-model"
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "from-peer")
|
||||
})
|
||||
|
||||
t.Run("local models take precedence over peer models", func(t *testing.T) {
|
||||
// Create a test server to act as the peer - should NOT be called
|
||||
peerCalled := false
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
peerCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config where "shared-model" exists both locally and on peer
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- shared-model
|
||||
models:
|
||||
shared-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-response
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"shared-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "local-response")
|
||||
assert.False(t, peerCalled, "peer should not be called when local model exists")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns error", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer API key is injected into request", func(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
apiKey: secret-peer-key
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
|
||||
})
|
||||
|
||||
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"local-model": getTestSimpleResponderConfig("local-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// peerProxy exists but has no peer models configured
|
||||
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("data: test\n\n"))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"sd-model": getTestSimpleResponderConfig("sd-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful txt2img with model", func(t *testing.T) {
|
||||
reqBody := `{"model":"sd-model","prompt":"a cat"}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "sd-model")
|
||||
})
|
||||
|
||||
t.Run("successful img2img with model", func(t *testing.T) {
|
||||
reqBody := `{"model":"sd-model","prompt":"a cat","init_images":[]}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/img2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "sd-model")
|
||||
})
|
||||
|
||||
t.Run("missing model returns 400", func(t *testing.T) {
|
||||
reqBody := `{"prompt":"a cat"}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing or invalid 'model' key")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_SdApiGetLoras(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"sd-model": getTestSimpleResponderConfig("sd-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful GET loras with model query param", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing model query param returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=nonexistent", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
node_modules
|
||||
.vite
|
||||
@@ -0,0 +1 @@
|
||||
legacy-peer-deps=true
|
||||
@@ -10,8 +10,8 @@
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body >
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"name": "ui-svelte",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"start": "vite",
|
||||
"build": "vite build --emptyOutDir",
|
||||
"preview": "vite preview",
|
||||
"check": "svelte-check --tsconfig ./tsconfig.json",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sveltejs/vite-plugin-svelte": "^7.0.0",
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@tsconfig/svelte": "^5.0.4",
|
||||
"@types/hast": "^3.0.4",
|
||||
"@types/node": "^25.1.0",
|
||||
"svelte": "^5.46.4",
|
||||
"svelte-check": "^4.1.4",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "~5.8.3",
|
||||
"vite": "^8.0.0",
|
||||
"vite-plugin-compression2": "^2.5.1",
|
||||
"vitest": "^4.1.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.28",
|
||||
"lucide-svelte": "^0.563.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"remark-parse": "^11.0.0",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-spa-router": "^4.0.1",
|
||||
"unified": "^11.0.5",
|
||||
"unist-util-visit": "^5.1.0"
|
||||
}
|
||||
}
|
||||
|
Before Width: | Height: | Size: 5.9 KiB After Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 38 KiB After Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 6.5 KiB After Width: | Height: | Size: 6.5 KiB |
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,58 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import Router from "svelte-spa-router";
|
||||
import Header from "./components/Header.svelte";
|
||||
import LogViewer from "./routes/LogViewer.svelte";
|
||||
import Models from "./routes/Models.svelte";
|
||||
import Activity from "./routes/Activity.svelte";
|
||||
import Playground from "./routes/Playground.svelte";
|
||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||
import { enableAPIEvents } from "./stores/api";
|
||||
import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||
import { currentRoute } from "./stores/route";
|
||||
|
||||
const routes = {
|
||||
"/": PlaygroundStub,
|
||||
"/models": Models,
|
||||
"/logs": LogViewer,
|
||||
"/activity": Activity,
|
||||
"*": PlaygroundStub,
|
||||
};
|
||||
|
||||
function handleRouteLoaded(event: { detail: { route: string | RegExp } }) {
|
||||
const route = event.detail.route;
|
||||
currentRoute.set(typeof route === "string" ? route : "/");
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
document.documentElement.setAttribute("data-theme", $isDarkMode ? "dark" : "light");
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
const icon = $connectionState === "connecting" ? "\u{1F7E1}" : $connectionState === "connected" ? "\u{1F7E2}" : "\u{1F534}";
|
||||
document.title = `${icon} ${$appTitle}`;
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
const cleanupScreenWidth = initScreenWidth();
|
||||
enableAPIEvents(true);
|
||||
|
||||
return () => {
|
||||
cleanupScreenWidth();
|
||||
enableAPIEvents(false);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-screen">
|
||||
<Header />
|
||||
|
||||
<main class="flex-1 overflow-auto p-4">
|
||||
<div class="h-full" class:hidden={$currentRoute !== "/"}>
|
||||
<Playground />
|
||||
</div>
|
||||
<div class="h-full" class:hidden={$currentRoute === "/"}>
|
||||
<Router {routes} on:routeLoaded={handleRouteLoaded} />
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,452 @@
|
||||
<script lang="ts">
|
||||
import type { ReqRespCapture } from "../lib/types";
|
||||
|
||||
interface Props {
|
||||
capture: ReqRespCapture | null;
|
||||
open: boolean;
|
||||
onclose: () => void;
|
||||
}
|
||||
|
||||
let { capture, open, onclose }: Props = $props();
|
||||
|
||||
let dialogEl: HTMLDialogElement | undefined = $state();
|
||||
|
||||
type BodyTab = "raw" | "pretty" | "chat";
|
||||
let reqBodyTab: BodyTab = $state("pretty");
|
||||
let respBodyTab: BodyTab = $state("pretty");
|
||||
let copiedReq = $state(false);
|
||||
let copiedResp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
if (open && dialogEl) {
|
||||
dialogEl.showModal();
|
||||
} else if (!open && dialogEl) {
|
||||
dialogEl.close();
|
||||
}
|
||||
});
|
||||
|
||||
// Reset tabs when capture changes
|
||||
$effect(() => {
|
||||
if (capture) {
|
||||
const reqCt = getContentType(capture.req_headers);
|
||||
const respCt = getContentType(capture.resp_headers);
|
||||
reqBodyTab = reqCt.includes("json") ? "pretty" : "raw";
|
||||
respBodyTab = respCt.includes("text/event-stream")
|
||||
? "chat"
|
||||
: respCt.includes("json")
|
||||
? "pretty"
|
||||
: "raw";
|
||||
}
|
||||
});
|
||||
|
||||
function handleDialogClose() {
|
||||
onclose();
|
||||
}
|
||||
|
||||
function decodeBody(body: string | null | undefined): string {
|
||||
if (!body) return "";
|
||||
try {
|
||||
const binary = atob(body);
|
||||
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
|
||||
return new TextDecoder().decode(bytes);
|
||||
} catch {
|
||||
return body;
|
||||
}
|
||||
}
|
||||
|
||||
function formatJson(str: string): string {
|
||||
try {
|
||||
const parsed = JSON.parse(str);
|
||||
return JSON.stringify(parsed, null, 2);
|
||||
} catch {
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
function getContentType(
|
||||
headers: Record<string, string> | null | undefined,
|
||||
): string {
|
||||
if (!headers) return "";
|
||||
const ct = headers["Content-Type"] || headers["content-type"] || "";
|
||||
return ct.toLowerCase();
|
||||
}
|
||||
|
||||
function isImageContentType(contentType: string): boolean {
|
||||
return contentType.startsWith("image/");
|
||||
}
|
||||
|
||||
function isTextContentType(contentType: string): boolean {
|
||||
return (
|
||||
contentType.startsWith("text/") ||
|
||||
contentType.includes("application/json") ||
|
||||
contentType.includes("application/xml") ||
|
||||
contentType.includes("application/javascript")
|
||||
);
|
||||
}
|
||||
|
||||
function getImageDataUrl(body: string, contentType: string): string {
|
||||
const mimeType = contentType.split(";")[0].trim();
|
||||
return `data:${mimeType};base64,${body}`;
|
||||
}
|
||||
|
||||
interface SSEChat {
|
||||
reasoning: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
function parseSSEChat(text: string): SSEChat {
|
||||
const result: SSEChat = { reasoning: "", content: "" };
|
||||
for (const line of text.split("\n")) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith("data: ")) continue;
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
const delta = parsed.choices?.[0]?.delta;
|
||||
if (delta?.content) result.content += delta.content;
|
||||
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
||||
} catch {
|
||||
// skip unparseable lines
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async function copyToClipboard(text: string, type: "req" | "resp") {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
if (type === "req") {
|
||||
copiedReq = true;
|
||||
setTimeout(() => (copiedReq = false), 1500);
|
||||
} else {
|
||||
copiedResp = true;
|
||||
setTimeout(() => (copiedResp = false), 1500);
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
function getCopyText(): string {
|
||||
if (respBodyTab === "chat") {
|
||||
let text = "";
|
||||
if (sseChat.reasoning) text += sseChat.reasoning + "\n\n";
|
||||
text += sseChat.content;
|
||||
return text;
|
||||
}
|
||||
return displayedResponseBody;
|
||||
}
|
||||
|
||||
// Request body derivations
|
||||
let requestContentType = $derived(
|
||||
capture ? getContentType(capture.req_headers) : "",
|
||||
);
|
||||
let isRequestJson = $derived(requestContentType.includes("json"));
|
||||
|
||||
let requestBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.req_body);
|
||||
});
|
||||
|
||||
let requestBodyPretty = $derived.by(() => {
|
||||
if (!isRequestJson) return requestBodyRaw;
|
||||
return formatJson(requestBodyRaw);
|
||||
});
|
||||
|
||||
let displayedRequestBody = $derived(
|
||||
reqBodyTab === "pretty" ? requestBodyPretty : requestBodyRaw,
|
||||
);
|
||||
|
||||
// Response body derivations
|
||||
let responseContentType = $derived(
|
||||
capture ? getContentType(capture.resp_headers) : "",
|
||||
);
|
||||
let isResponseImage = $derived(isImageContentType(responseContentType));
|
||||
let isResponseText = $derived(isTextContentType(responseContentType));
|
||||
let isResponseJson = $derived(responseContentType.includes("json"));
|
||||
let isSSE = $derived(responseContentType.includes("text/event-stream"));
|
||||
|
||||
let responseBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.resp_body);
|
||||
});
|
||||
|
||||
let responseBodyPretty = $derived.by(() => {
|
||||
if (!isResponseJson) return responseBodyRaw;
|
||||
return formatJson(responseBodyRaw);
|
||||
});
|
||||
|
||||
let sseChat = $derived.by(() => {
|
||||
if (!isSSE || !responseBodyRaw)
|
||||
return { reasoning: "", content: "" } as SSEChat;
|
||||
return parseSSEChat(responseBodyRaw);
|
||||
});
|
||||
|
||||
let displayedResponseBody = $derived.by(() => {
|
||||
if (respBodyTab === "pretty") return responseBodyPretty;
|
||||
return responseBodyRaw;
|
||||
});
|
||||
</script>
|
||||
|
||||
<dialog
|
||||
bind:this={dialogEl}
|
||||
onclose={handleDialogClose}
|
||||
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
||||
>
|
||||
{#if capture}
|
||||
<div class="flex flex-col max-h-[90vh]">
|
||||
<div
|
||||
class="flex justify-between items-center p-4 border-b border-card-border"
|
||||
>
|
||||
<h2 class="text-xl font-bold pb-0">Capture #{capture.id + 1}{#if capture.req_path} <span class="text-base font-mono font-normal text-txtsecondary">{capture.req_path}</span>{/if}</h2>
|
||||
<button
|
||||
onclick={() => dialogEl?.close()}
|
||||
class="text-txtsecondary hover:text-txtmain text-2xl leading-none"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="overflow-y-auto flex-1 p-4 space-y-4">
|
||||
<!-- Request Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.req_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Request Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Body
|
||||
</summary>
|
||||
{#if requestBodyRaw}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isRequestJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "pretty"}
|
||||
onclick={() => (reqBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "raw"}
|
||||
onclick={() => (reqBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() =>
|
||||
copyToClipboard(displayedRequestBody, "req")}
|
||||
>
|
||||
{#if copiedReq}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedRequestBody}</pre>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono whitespace-pre-wrap break-all"
|
||||
>(empty)</pre
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
|
||||
<!-- Response Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.resp_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Response Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Body
|
||||
</summary>
|
||||
{#if isResponseImage && capture.resp_body}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 flex justify-center">
|
||||
<img
|
||||
src={getImageDataUrl(capture.resp_body, responseContentType)}
|
||||
alt="Response"
|
||||
class="max-w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSSE || isResponseText}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isSSE}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "chat"}
|
||||
onclick={() => (respBodyTab = "chat")}>Chat</button
|
||||
>
|
||||
{/if}
|
||||
{#if isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "pretty"}
|
||||
onclick={() => (respBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
{/if}
|
||||
{#if isSSE || isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "raw"}
|
||||
onclick={() => (respBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() => copyToClipboard(getCopyText(), "resp")}
|
||||
>
|
||||
{#if copiedResp}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
{#if respBodyTab === "chat"}
|
||||
<div class="p-3 text-sm space-y-3">
|
||||
{#if sseChat.reasoning}
|
||||
<div>
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Reasoning
|
||||
</div>
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all text-txtsecondary">{sseChat.reasoning}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if sseChat.content}
|
||||
<div>
|
||||
{#if sseChat.reasoning}
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Response
|
||||
</div>
|
||||
{/if}
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all">{sseChat.content}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if !sseChat.reasoning && !sseChat.content}
|
||||
<pre class="font-mono">(empty)</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else}
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedResponseBody || "(empty)"}</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if responseBodyRaw}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 text-sm text-txtsecondary italic">
|
||||
(binary data - {responseContentType || "unknown content type"})
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono">(empty)</pre>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
</div>
|
||||
|
||||
<div class="p-4 border-t border-card-border flex justify-end">
|
||||
<button onclick={() => dialogEl?.close()} class="btn"> Close </button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</dialog>
|
||||
|
||||
<style>
|
||||
.tab-btn {
|
||||
padding: 2px 10px;
|
||||
font-size: 0.75rem;
|
||||
border-radius: 4px;
|
||||
color: var(--color-txtsecondary);
|
||||
cursor: pointer;
|
||||
border: 1px solid transparent;
|
||||
background: transparent;
|
||||
transition: all 0.15s;
|
||||
}
|
||||
.tab-btn:hover {
|
||||
color: var(--color-txtmain);
|
||||
background: var(--color-secondary);
|
||||
}
|
||||
.tab-btn-active {
|
||||
color: var(--color-primary);
|
||||
background: color-mix(in srgb, var(--color-primary) 12%, transparent);
|
||||
border-color: color-mix(in srgb, var(--color-primary) 25%, transparent);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,24 @@
|
||||
<script lang="ts">
|
||||
import { connectionState } from "../stores/theme";
|
||||
import { versionInfo } from "../stores/api";
|
||||
|
||||
let eventStatusColor = $derived.by(() => {
|
||||
switch ($connectionState) {
|
||||
case "connected":
|
||||
return "bg-emerald-500";
|
||||
case "connecting":
|
||||
return "bg-amber-500";
|
||||
case "disconnected":
|
||||
default:
|
||||
return "bg-red-500";
|
||||
}
|
||||
});
|
||||
|
||||
let tooltipText = $derived(
|
||||
`Event Stream: ${$connectionState ?? "unknown"}\nAPI Version: ${$versionInfo?.version ?? "unknown"}\nCommit Hash: ${$versionInfo?.commit?.substring(0, 7) ?? "unknown"}\nBuild Date: ${$versionInfo?.build_date ?? "unknown"}`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="flex items-center" title={tooltipText}>
|
||||
<span class="inline-block w-3 h-3 rounded-full {eventStatusColor} mr-2"></span>
|
||||
</div>
|
||||
@@ -0,0 +1,120 @@
|
||||
<script lang="ts">
|
||||
import { link } from "svelte-spa-router";
|
||||
import { screenWidth, toggleTheme, isDarkMode, appTitle, isNarrow } from "../stores/theme";
|
||||
import { currentRoute } from "../stores/route";
|
||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||
|
||||
function handleTitleChange(newTitle: string): void {
|
||||
const sanitized = newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap";
|
||||
appTitle.set(sanitized);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
target.blur();
|
||||
}
|
||||
}
|
||||
|
||||
function handleBlur(e: FocusEvent): void {
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
}
|
||||
|
||||
function isActive(path: string, current: string): boolean {
|
||||
return path === "/" ? current === "/" : current.startsWith(path);
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<header
|
||||
class="flex items-center justify-between bg-surface border-b border-border px-4 {$isNarrow
|
||||
? 'py-1 h-[60px]'
|
||||
: 'p-2 h-[75px]'}"
|
||||
>
|
||||
{#if $screenWidth !== "xs" && $screenWidth !== "sm"}
|
||||
<h1
|
||||
contenteditable="true"
|
||||
class="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||
onblur={handleBlur}
|
||||
onkeydown={handleKeyDown}
|
||||
>
|
||||
{$appTitle}
|
||||
</h1>
|
||||
{/if}
|
||||
|
||||
<menu class="flex items-center gap-4 overflow-x-auto">
|
||||
<a
|
||||
href="/"
|
||||
use:link
|
||||
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
||||
>
|
||||
Playground
|
||||
</a>
|
||||
<a
|
||||
href="/models"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/models", $currentRoute)}
|
||||
>
|
||||
Models
|
||||
</a>
|
||||
<a
|
||||
href="/activity"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/activity", $currentRoute)}
|
||||
>
|
||||
Activity
|
||||
</a>
|
||||
<a
|
||||
href="/logs"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/logs", $currentRoute)}
|
||||
>
|
||||
Logs
|
||||
</a>
|
||||
<button onclick={toggleTheme} title="Toggle theme">
|
||||
{#if $isDarkMode}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M9.528 1.718a.75.75 0 0 1 .162.819A8.97 8.97 0 0 0 9 6a9 9 0 0 0 9 9 8.97 8.97 0 0 0 3.463-.69.75.75 0 0 1 .981.98 10.503 10.503 0 0 1-9.694 6.46c-5.799 0-10.5-4.7-10.5-10.5 0-4.368 2.667-8.112 6.46-9.694a.75.75 0 0 1 .818.162Z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
d="M12 2.25a.75.75 0 0 1 .75.75v2.25a.75.75 0 0 1-1.5 0V3a.75.75 0 0 1 .75-.75ZM7.5 12a4.5 4.5 0 1 1 9 0 4.5 4.5 0 0 1-9 0ZM18.894 6.166a.75.75 0 0 0-1.06-1.06l-1.591 1.59a.75.75 0 1 0 1.06 1.061l1.591-1.59ZM21.75 12a.75.75 0 0 1-.75.75h-2.25a.75.75 0 0 1 0-1.5H21a.75.75 0 0 1 .75.75ZM17.834 18.894a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 1 0-1.061 1.06l1.591 1.591ZM12 18a.75.75 0 0 1 .75.75V21a.75.75 0 0 1-1.5 0v-2.25A.75.75 0 0 1 12 18ZM7.758 17.303a.75.75 0 0 0-1.061-1.06l-1.591 1.59a.75.75 0 0 0 1.06 1.061l1.591-1.59ZM6 12a.75.75 0 0 1-.75.75H3a.75.75 0 0 1 0-1.5h2.25A.75.75 0 0 1 6 12ZM6.697 7.757a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 0 0-1.061 1.06l1.59 1.591Z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<ConnectionStatus />
|
||||
</menu>
|
||||
</header>
|
||||
|
||||
<style>
|
||||
.activity-link {
|
||||
background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7, #8b5cf6, #6366f1);
|
||||
background-size: 200% 100%;
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
animation: gradient-shift 2s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes gradient-shift {
|
||||
0% {
|
||||
background-position: 0% 50%;
|
||||
}
|
||||
100% {
|
||||
background-position: 200% 50%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,139 @@
|
||||
<script lang="ts">
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
}
|
||||
|
||||
let { id, title, logData }: Props = $props();
|
||||
|
||||
let filterRegex = $state("");
|
||||
|
||||
// Create persistent stores for this panel (id is intentionally captured at init time)
|
||||
// svelte-ignore state_referenced_locally
|
||||
const fontSizeStore = persistentStore<"xxs" | "xs" | "small" | "normal">(`logPanel-${id}-fontSize`, "normal");
|
||||
// svelte-ignore state_referenced_locally
|
||||
const wrapTextStore = persistentStore<boolean>(`logPanel-${id}-wrapText`, false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
const showFilterStore = persistentStore<boolean>(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
let textWrapClass = $derived($wrapTextStore ? "whitespace-pre-wrap" : "whitespace-pre");
|
||||
|
||||
function toggleFontSize(): void {
|
||||
fontSizeStore.update((prev) => {
|
||||
switch (prev) {
|
||||
case "xxs": return "xs";
|
||||
case "xs": return "small";
|
||||
case "small": return "normal";
|
||||
case "normal": return "xxs";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function toggleWrapText(): void {
|
||||
wrapTextStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function toggleFilter(): void {
|
||||
if ($showFilterStore) {
|
||||
showFilterStore.set(false);
|
||||
filterRegex = "";
|
||||
} else {
|
||||
showFilterStore.set(true);
|
||||
}
|
||||
}
|
||||
|
||||
let fontSizeClass = $derived.by(() => {
|
||||
switch ($fontSizeStore) {
|
||||
case "xxs": return "text-[0.5rem]";
|
||||
case "xs": return "text-[0.75rem]";
|
||||
case "small": return "text-[0.875rem]";
|
||||
case "normal": return "text-base";
|
||||
}
|
||||
});
|
||||
|
||||
let filteredLogs = $derived.by(() => {
|
||||
if (!filterRegex) return logData;
|
||||
try {
|
||||
const regex = new RegExp(filterRegex, "i");
|
||||
return logData.split("\n").filter((line) => regex.test(line)).join("\n");
|
||||
} catch {
|
||||
return logData;
|
||||
}
|
||||
});
|
||||
|
||||
let preElement: HTMLPreElement;
|
||||
let userScrolledUp = $state(false);
|
||||
|
||||
function handleScroll() {
|
||||
if (!preElement) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = preElement;
|
||||
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||
}
|
||||
|
||||
// Auto scroll to bottom when logs change, unless user has scrolled up
|
||||
$effect(() => {
|
||||
if (preElement && filteredLogs && !userScrolledUp) {
|
||||
preElement.scrollTop = preElement.scrollHeight;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full w-full p-1">
|
||||
<div class="p-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div class="flex gap-2 items-center">
|
||||
<button class="btn border-0" onclick={toggleFontSize} title="Change font size">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path d="M2 4v3h5v12h3V7h5V4H2zm19 5h-9v3h3v7h3v-7h3V9z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleWrapText} title="Toggle text wrap">
|
||||
{#if $wrapTextStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h10.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleFilter} title="Toggle filter">
|
||||
{#if $showFilterStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M10.5 3.75a6.75 6.75 0 1 0 0 13.5 6.75 6.75 0 0 0 0-13.5ZM2.25 10.5a8.25 8.25 0 1 1 14.59 5.28l4.69 4.69a.75.75 0 1 1-1.06 1.06l-4.69-4.69A8.25 8.25 0 0 1 2.25 10.5Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if $showFilterStore}
|
||||
<div class="mt-2 flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
class="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||
placeholder="Filter logs (regex)..."
|
||||
bind:value={filterRegex}
|
||||
/>
|
||||
<button class="pl-2" onclick={() => (filterRegex = "")} aria-label="Clear filter">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm-1.72 6.97a.75.75 0 1 0-1.06 1.06L10.94 12l-1.72 1.72a.75.75 0 1 0 1.06 1.06L12 13.06l1.72 1.72a.75.75 0 1 0 1.06-1.06L13.06 12l1.72-1.72a.75.75 0 1 0-1.06-1.06L12 10.94l-1.72-1.72Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre bind:this={preElement} onscroll={handleScroll} class="{textWrapClass} {fontSizeClass} h-full overflow-auto p-4">{filteredLogs}</pre>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,211 @@
|
||||
<script lang="ts">
|
||||
import { models, loadModel, unloadAllModels, unloadSingleModel } from "../stores/api";
|
||||
import { isNarrow } from "../stores/theme";
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import type { Model } from "../lib/types";
|
||||
|
||||
let isUnloading = $state(false);
|
||||
let menuOpen = $state(false);
|
||||
|
||||
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||
|
||||
let filteredModels = $derived.by(() => {
|
||||
const filtered = $models.filter((model) => $showUnlistedStore || !model.unlisted);
|
||||
const peerModels = filtered.filter((m) => m.peerID);
|
||||
|
||||
// Group peer models by peerID
|
||||
const grouped = peerModels.reduce(
|
||||
(acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
if (!acc[peerId]) acc[peerId] = [];
|
||||
acc[peerId].push(model);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, Model[]>
|
||||
);
|
||||
|
||||
return {
|
||||
regularModels: filtered.filter((m) => !m.peerID),
|
||||
peerModelsByPeerId: grouped,
|
||||
};
|
||||
});
|
||||
|
||||
async function handleUnloadAllModels(): Promise<void> {
|
||||
isUnloading = true;
|
||||
try {
|
||||
await unloadAllModels();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
setTimeout(() => (isUnloading = false), 1000);
|
||||
}
|
||||
}
|
||||
|
||||
function toggleIdorName(): void {
|
||||
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||
}
|
||||
|
||||
function toggleShowUnlisted(): void {
|
||||
showUnlistedStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function getModelDisplay(model: Model): string {
|
||||
return $showIdorNameStore === "id" ? model.id : (model.name || model.id);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="card h-full flex flex-col">
|
||||
<div class="shrink-0">
|
||||
<div class="flex justify-between items-baseline">
|
||||
<h2 class={$isNarrow ? "text-xl" : ""}>Models</h2>
|
||||
{#if $isNarrow}
|
||||
<div class="relative">
|
||||
<button class="btn text-base flex items-center gap-2 py-1" onclick={() => (menuOpen = !menuOpen)} aria-label="Toggle menu">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if menuOpen}
|
||||
<div class="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleIdorName(); menuOpen = false; }}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "Show Name" : "Show ID"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleShowUnlisted(); menuOpen = false; }}
|
||||
>
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
{$showUnlistedStore ? "Hide Unlisted" : "Show Unlisted"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { handleUnloadAllModels(); menuOpen = false; }}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{#if !$isNarrow}
|
||||
<div class="flex justify-between">
|
||||
<div class="flex gap-2">
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleIdorName} style="line-height: 1.2">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleShowUnlisted} style="line-height: 1.2">
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
unlisted
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn text-base flex items-center gap-2" onclick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th>{$showIdorNameStore === "id" ? "Model ID" : "Name"}</th>
|
||||
<th></th>
|
||||
<th>State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each filteredModels.regularModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class={model.unlisted ? "text-txtsecondary" : ""}>
|
||||
<a href="/upstream/{model.id}/" class="font-semibold" target="_blank">
|
||||
{getModelDisplay(model)}
|
||||
</a>
|
||||
{#if model.description}
|
||||
<p class={model.unlisted ? "text-opacity-70" : ""}><em>{model.description}</em></p>
|
||||
{/if}
|
||||
{#if model.aliases && model.aliases.length > 0}
|
||||
<p class="text-xs text-txtsecondary">Aliases: {model.aliases.join(", ")}</p>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-12">
|
||||
{#if model.state === "stopped"}
|
||||
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
||||
{:else}
|
||||
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-20">
|
||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{#if Object.keys(filteredModels.peerModelsByPeerId).length > 0}
|
||||
<h3 class="mt-8 mb-2">Peer Models</h3>
|
||||
{#each Object.entries(filteredModels.peerModelsByPeerId).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||
<div class="mb-4">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th class="font-semibold">{peerId}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each peerModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class="pl-8 {model.unlisted ? 'text-txtsecondary' : ''}">
|
||||
<span>{model.id}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,152 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet } from "svelte";
|
||||
import { onMount } from "svelte";
|
||||
|
||||
interface Props {
|
||||
direction: "horizontal" | "vertical";
|
||||
storageKey: string;
|
||||
leftPanel: Snippet;
|
||||
rightPanel: Snippet;
|
||||
defaultSize?: number;
|
||||
minSize?: number;
|
||||
}
|
||||
|
||||
let { direction, storageKey, leftPanel, rightPanel, defaultSize = 50, minSize = 5 }: Props = $props();
|
||||
|
||||
let containerRef: HTMLDivElement;
|
||||
let isDragging = $state(false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
let leftSize = $state(defaultSize);
|
||||
|
||||
// Load saved size from localStorage
|
||||
onMount(() => {
|
||||
const saved = localStorage.getItem(`panel-size-${storageKey}`);
|
||||
if (saved) {
|
||||
const parsed = parseFloat(saved);
|
||||
if (!isNaN(parsed) && parsed >= minSize && parsed <= 100 - minSize) {
|
||||
leftSize = parsed;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function saveSize(): void {
|
||||
localStorage.setItem(`panel-size-${storageKey}`, String(leftSize));
|
||||
}
|
||||
|
||||
function handleMouseDown(e: MouseEvent): void {
|
||||
e.preventDefault();
|
||||
isDragging = true;
|
||||
document.addEventListener("mousemove", handleMouseMove);
|
||||
document.addEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchStart(_e: TouchEvent): void {
|
||||
isDragging = true;
|
||||
document.addEventListener("touchmove", handleTouchMove);
|
||||
document.addEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleMouseMove(e: MouseEvent): void {
|
||||
if (!isDragging || !containerRef) return;
|
||||
updateSize(e.clientX, e.clientY);
|
||||
}
|
||||
|
||||
function handleTouchMove(e: TouchEvent): void {
|
||||
if (!isDragging || !containerRef || e.touches.length === 0) return;
|
||||
updateSize(e.touches[0].clientX, e.touches[0].clientY);
|
||||
}
|
||||
|
||||
function updateSize(clientX: number, clientY: number): void {
|
||||
const rect = containerRef.getBoundingClientRect();
|
||||
|
||||
let newSize: number;
|
||||
if (direction === "horizontal") {
|
||||
newSize = ((clientX - rect.left) / rect.width) * 100;
|
||||
} else {
|
||||
newSize = ((clientY - rect.top) / rect.height) * 100;
|
||||
}
|
||||
|
||||
// Clamp size
|
||||
newSize = Math.max(minSize, Math.min(100 - minSize, newSize));
|
||||
leftSize = newSize;
|
||||
}
|
||||
|
||||
function handleMouseUp(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("mousemove", handleMouseMove);
|
||||
document.removeEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchEnd(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("touchmove", handleTouchMove);
|
||||
document.removeEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
const step = 2; // 2% increment for keyboard navigation
|
||||
const key = e.key;
|
||||
|
||||
if (direction === "horizontal" && (key === "ArrowLeft" || key === "ArrowRight")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowLeft" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
} else if (direction === "vertical" && (key === "ArrowUp" || key === "ArrowDown")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowUp" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
}
|
||||
}
|
||||
|
||||
let containerClass = $derived(direction === "horizontal" ? "flex-row" : "flex-col");
|
||||
|
||||
let handleClass = $derived(
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full cursor-col-resize"
|
||||
: "w-full h-2 cursor-row-resize"
|
||||
);
|
||||
|
||||
let leftStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
|
||||
let rightStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${100 - leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${100 - leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="flex {containerClass} h-full w-full gap-2">
|
||||
<div style={leftStyle} class="overflow-hidden">
|
||||
{@render leftPanel()}
|
||||
</div>
|
||||
|
||||
<!-- svelte-ignore a11y_no_noninteractive_tabindex -->
|
||||
<!-- svelte-ignore a11y_no_noninteractive_element_interactions -->
|
||||
<div
|
||||
role="separator"
|
||||
tabindex="0"
|
||||
class="{handleClass} bg-primary hover:bg-success transition-colors rounded flex-shrink-0"
|
||||
onmousedown={handleMouseDown}
|
||||
ontouchstart={handleTouchStart}
|
||||
onkeydown={handleKeyDown}
|
||||
aria-label="Resize panels"
|
||||
aria-orientation={direction}
|
||||
aria-valuenow={Math.round(leftSize)}
|
||||
aria-valuemin={minSize}
|
||||
aria-valuemax={100 - minSize}
|
||||
></div>
|
||||
|
||||
<div style={rightStyle} class="overflow-hidden">
|
||||
{@render rightPanel()}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,167 @@
|
||||
<script lang="ts">
|
||||
import { inFlightRequests, metrics } from "../stores/api";
|
||||
import TokenHistogram from "./TokenHistogram.svelte";
|
||||
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
let stats = $derived.by(() => {
|
||||
const totalRequests = $metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return {
|
||||
totalRequests: 0,
|
||||
totalInputTokens: 0,
|
||||
totalOutputTokens: 0,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
|
||||
// Calculate token statistics using output_tokens and duration_ms
|
||||
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
||||
if (validMetrics.length === 0) {
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
// Calculate tokens/second for each valid metric
|
||||
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
||||
|
||||
// Sort for percentile calculation
|
||||
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
||||
|
||||
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
||||
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
||||
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
||||
|
||||
// Create histogram data
|
||||
const min = Math.min(...tokensPerSecond);
|
||||
const max = Math.max(...tokensPerSecond);
|
||||
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
|
||||
const binSize = (max - min) / binCount;
|
||||
|
||||
const bins = Array(binCount).fill(0);
|
||||
tokensPerSecond.forEach((value) => {
|
||||
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
||||
bins[binIndex]++;
|
||||
});
|
||||
|
||||
const histogramData: HistogramData = {
|
||||
bins,
|
||||
min,
|
||||
max,
|
||||
binSize,
|
||||
p99,
|
||||
p95,
|
||||
p50,
|
||||
};
|
||||
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: {
|
||||
p99: p99.toFixed(2),
|
||||
p95: p95.toFixed(2),
|
||||
p50: p50.toFixed(2),
|
||||
},
|
||||
histogramData,
|
||||
};
|
||||
});
|
||||
|
||||
const nf = new Intl.NumberFormat();
|
||||
</script>
|
||||
|
||||
<div class="card">
|
||||
<div class="rounded-lg overflow-hidden border border-card-border-inner">
|
||||
<table class="min-w-full divide-y divide-card-border-inner">
|
||||
<thead class="bg-secondary">
|
||||
<tr>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Processed
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Generated
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Token Stats (tokens/sec)
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
||||
<tbody class="bg-surface divide-y divide-card-border-inner">
|
||||
<tr class="hover:bg-secondary">
|
||||
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="space-y-3">
|
||||
<div class="grid grid-cols-3 gap-2 items-center">
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p50}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p95}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p99}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{#if stats.histogramData}
|
||||
<TokenHistogram data={stats.histogramData} />
|
||||
{/if}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,129 @@
|
||||
<script lang="ts">
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
data: HistogramData;
|
||||
}
|
||||
|
||||
let { data }: Props = $props();
|
||||
|
||||
const height = 120;
|
||||
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
||||
const viewBoxWidth = 600;
|
||||
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||
const chartHeight = height - padding.top - padding.bottom;
|
||||
|
||||
let maxCount = $derived(Math.max(...data.bins));
|
||||
let barWidth = $derived(chartWidth / data.bins.length);
|
||||
let range = $derived(data.max - data.min);
|
||||
|
||||
function getXPosition(value: number): number {
|
||||
return padding.left + ((value - data.min) / range) * chartWidth;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="mt-2 w-full">
|
||||
<svg viewBox="0 0 {viewBoxWidth} {height}" class="w-full h-auto" preserveAspectRatio="xMidYMid meet">
|
||||
<!-- Y-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={padding.top}
|
||||
x2={padding.left}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- X-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={height - padding.bottom}
|
||||
x2={viewBoxWidth - padding.right}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- Histogram bars -->
|
||||
{#each data.bins as count, i}
|
||||
{@const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0}
|
||||
{@const x = padding.left + i * barWidth}
|
||||
{@const y = height - padding.bottom - barHeight}
|
||||
{@const binStart = data.min + i * data.binSize}
|
||||
{@const binEnd = binStart + data.binSize}
|
||||
<g>
|
||||
<rect
|
||||
{x}
|
||||
{y}
|
||||
width={Math.max(barWidth - 1, 1)}
|
||||
height={barHeight}
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
||||
/>
|
||||
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
||||
</g>
|
||||
{/each}
|
||||
|
||||
<!-- Percentile lines -->
|
||||
<line
|
||||
x1={getXPosition(data.p50)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p50)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-gray-600 dark:text-gray-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p95)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p95)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-orange-500 dark:text-orange-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p99)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p99)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-green-500 dark:text-green-400"
|
||||
/>
|
||||
|
||||
<!-- X-axis labels -->
|
||||
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start">
|
||||
{data.min.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end">
|
||||
{data.max.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<!-- X-axis label -->
|
||||
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
|
||||
Tokens/Second Distribution
|
||||
</text>
|
||||
</svg>
|
||||
</div>
|
||||
@@ -0,0 +1,20 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
content: string;
|
||||
}
|
||||
|
||||
let { content }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative group inline-block">
|
||||
<span class="cursor-help">ⓘ</span>
|
||||
<div
|
||||
class="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||
opacity-0 group-hover:opacity-100 transition-opacity
|
||||
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||
>
|
||||
{content}
|
||||
<div class="absolute bottom-full left-1/2 transform -translate-x-1/2 border-4 border-transparent border-b-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,256 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { transcribeAudio } from "../../lib/audioApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-audio-model", "");
|
||||
|
||||
let selectedFile = $state<File | null>(null);
|
||||
let isTranscribing = $state(false);
|
||||
let transcriptionResult = $state<string | null>(null);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let isDragging = $state(false);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let copied = $state(false);
|
||||
|
||||
const ACCEPTED_FORMATS = ['.mp3', '.wav', '.ogg'];
|
||||
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
let canTranscribe = $derived(selectedFile !== null && $selectedModelStore !== "" && !isTranscribing);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.audioTranscribing.set(isTranscribing);
|
||||
});
|
||||
|
||||
function validateFile(file: File): { valid: boolean; error?: string } {
|
||||
const ext = '.' + file.name.split('.').pop()?.toLowerCase();
|
||||
|
||||
if (!ACCEPTED_FORMATS.includes(ext)) {
|
||||
return { valid: false, error: 'Invalid file type. Accepted: MP3, WAV, OGG' };
|
||||
}
|
||||
|
||||
if (file.size > MAX_FILE_SIZE) {
|
||||
return { valid: false, error: 'File too large. Maximum: 25MB' };
|
||||
}
|
||||
|
||||
return { valid: true };
|
||||
}
|
||||
|
||||
function handleFileSelect(event: Event) {
|
||||
const target = event.target as HTMLInputElement;
|
||||
const file = target.files?.[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = true;
|
||||
}
|
||||
|
||||
function handleDragLeave() {
|
||||
isDragging = false;
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = false;
|
||||
|
||||
const file = event.dataTransfer?.files[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function transcribe() {
|
||||
if (!selectedFile || !$selectedModelStore || isTranscribing) return;
|
||||
|
||||
isTranscribing = true;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const response = await transcribeAudio(
|
||||
$selectedModelStore,
|
||||
selectedFile,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
transcriptionResult = response.text;
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isTranscribing = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelTranscription() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearAll() {
|
||||
selectedFile = null;
|
||||
transcriptionResult = null;
|
||||
error = null;
|
||||
if (fileInput) {
|
||||
fileInput.value = '';
|
||||
}
|
||||
}
|
||||
|
||||
function copyToClipboard() {
|
||||
if (transcriptionResult) {
|
||||
navigator.clipboard.writeText(transcriptionResult);
|
||||
copied = true;
|
||||
setTimeout(() => {
|
||||
copied = false;
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
function formatFileSize(bytes: number): string {
|
||||
if (bytes < 1024) return bytes + ' B';
|
||||
if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB';
|
||||
return (bytes / (1024 * 1024)).toFixed(1) + ' MB';
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to transcribe audio.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- File upload / Result display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isTranscribing}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Transcribing audio...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if transcriptionResult}
|
||||
<div class="w-full h-full flex flex-col p-4">
|
||||
<div class="flex justify-between items-center mb-2">
|
||||
<h3 class="font-medium">Transcription Result</h3>
|
||||
<button
|
||||
class="btn btn-sm"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? 'Copied!' : 'Copy to clipboard'}
|
||||
>
|
||||
{#if copied}
|
||||
<svg class="w-5 h-5 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex-1 overflow-auto p-3 rounded border border-gray-200 dark:border-white/10 bg-background whitespace-pre-wrap">
|
||||
{transcriptionResult}
|
||||
</div>
|
||||
</div>
|
||||
{:else if selectedFile}
|
||||
<div class="text-center text-txtsecondary p-4">
|
||||
<p class="font-medium mb-2">File Selected</p>
|
||||
<p class="text-sm">{selectedFile.name}</p>
|
||||
<p class="text-xs mt-1">{formatFileSize(selectedFile.size)}</p>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
role="region"
|
||||
aria-label="Audio file drop zone"
|
||||
class="w-full h-full flex items-center justify-center text-center text-txtsecondary p-8 {isDragging ? 'bg-primary/10' : ''}"
|
||||
ondragover={handleDragOver}
|
||||
ondragleave={handleDragLeave}
|
||||
ondrop={handleDrop}
|
||||
>
|
||||
<div>
|
||||
<p class="mb-2">Drag and drop an audio file here</p>
|
||||
<p class="text-sm">or use the Browse button below</p>
|
||||
<p class="text-xs mt-4">Accepted formats: MP3, WAV, OGG (max 25MB)</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- File input and transcribe button -->
|
||||
<div class="shrink-0 flex gap-2">
|
||||
<input
|
||||
type="file"
|
||||
accept=".mp3,.wav,.ogg"
|
||||
class="hidden"
|
||||
onchange={handleFileSelect}
|
||||
bind:this={fileInput}
|
||||
/>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isTranscribing}
|
||||
>
|
||||
Browse Files
|
||||
</button>
|
||||
<div class="flex-1"></div>
|
||||
{#if isTranscribing}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelTranscription}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={transcribe}
|
||||
disabled={!canTranscribe}
|
||||
>
|
||||
Transcribe
|
||||
</button>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={clearAll}
|
||||
disabled={!selectedFile && !transcriptionResult && !error}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,466 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { streamChatCompletion } from "../../lib/chatApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import type { ChatMessage, ContentPart } from "../../lib/types";
|
||||
import ChatMessageComponent from "./ChatMessage.svelte";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-selected-model", "");
|
||||
const systemPromptStore = persistentStore<string>("playground-system-prompt", "");
|
||||
const temperatureStore = persistentStore<number>("playground-temperature", 0.7);
|
||||
|
||||
function loadMessages(): ChatMessage[] {
|
||||
try {
|
||||
const saved = localStorage.getItem("playground-messages");
|
||||
return saved ? JSON.parse(saved) : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
let messages = $state<ChatMessage[]>(loadMessages());
|
||||
let userInput = $state("");
|
||||
let isStreaming = $state(false);
|
||||
let isReasoning = $state(false);
|
||||
let reasoningStartTime = $state<number>(0);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let messagesContainer: HTMLDivElement | undefined = $state();
|
||||
let showSettings = $state(false);
|
||||
let attachedImages = $state<string[]>([]);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let imageError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
let userScrolledUp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.chatStreaming.set(isStreaming);
|
||||
});
|
||||
|
||||
function handleMessagesScroll() {
|
||||
if (!messagesContainer) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = messagesContainer;
|
||||
// Consider "at bottom" if within 40px of the bottom
|
||||
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||
}
|
||||
|
||||
// Auto-scroll when messages change — skip if user scrolled up
|
||||
$effect(() => {
|
||||
if (messages.length > 0 && messagesContainer && !userScrolledUp) {
|
||||
messagesContainer.scrollTo({
|
||||
top: messagesContainer.scrollHeight,
|
||||
behavior: isStreaming ? "instant" : "smooth",
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Persist messages to localStorage (throttled to once per 2s)
|
||||
let lastSaveTime = 0;
|
||||
$effect(() => {
|
||||
const json = JSON.stringify(messages);
|
||||
const elapsed = Date.now() - lastSaveTime;
|
||||
const save = () => {
|
||||
try { localStorage.setItem("playground-messages", json); } catch {}
|
||||
lastSaveTime = Date.now();
|
||||
};
|
||||
if (elapsed >= 2000) {
|
||||
save();
|
||||
return;
|
||||
}
|
||||
const timer = setTimeout(save, 2000 - elapsed);
|
||||
return () => clearTimeout(timer);
|
||||
});
|
||||
|
||||
async function sendMessage() {
|
||||
const trimmedInput = userInput.trim();
|
||||
if ((!trimmedInput && attachedImages.length === 0) || !$selectedModelStore || isStreaming) return;
|
||||
|
||||
userScrolledUp = false;
|
||||
|
||||
// Build message content (multimodal if images attached)
|
||||
let content: string | ContentPart[];
|
||||
if (attachedImages.length > 0) {
|
||||
const parts: ContentPart[] = [];
|
||||
if (trimmedInput) {
|
||||
parts.push({ type: "text", text: trimmedInput });
|
||||
}
|
||||
for (const url of attachedImages) {
|
||||
parts.push({ type: "image_url", image_url: { url } });
|
||||
}
|
||||
content = parts;
|
||||
} else {
|
||||
content = trimmedInput;
|
||||
}
|
||||
|
||||
// Add user message
|
||||
messages = [...messages, { role: "user", content }];
|
||||
userInput = "";
|
||||
attachedImages = [];
|
||||
imageError = null;
|
||||
|
||||
// Generate response from the new user message
|
||||
await regenerateFromIndex(messages.length - 1);
|
||||
}
|
||||
|
||||
function cancelStreaming() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function newChat() {
|
||||
if (isStreaming) {
|
||||
cancelStreaming();
|
||||
}
|
||||
messages = [];
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
}
|
||||
|
||||
async function regenerateFromIndex(idx: number) {
|
||||
// Remove all messages after the edited user message
|
||||
messages = messages.slice(0, idx + 1);
|
||||
|
||||
// Add empty assistant message for the new response
|
||||
messages = [...messages, { role: "assistant", content: "" }];
|
||||
|
||||
isStreaming = true;
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
// Build messages array with optional system prompt
|
||||
const apiMessages: ChatMessage[] = [];
|
||||
if ($systemPromptStore.trim()) {
|
||||
apiMessages.push({ role: "system", content: $systemPromptStore.trim() });
|
||||
}
|
||||
apiMessages.push(...messages.slice(0, -1)); // Add all messages except the empty assistant one
|
||||
|
||||
const stream = streamChatCompletion(
|
||||
$selectedModelStore,
|
||||
apiMessages,
|
||||
abortController.signal,
|
||||
{ temperature: $temperatureStore }
|
||||
);
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.done) break;
|
||||
|
||||
// Handle reasoning content
|
||||
if (chunk.reasoning_content) {
|
||||
// Start timing on first reasoning content
|
||||
if (!isReasoning) {
|
||||
isReasoning = true;
|
||||
reasoningStartTime = Date.now();
|
||||
}
|
||||
|
||||
// Update the last message with reasoning content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoning_content: (msg.reasoning_content || "") + chunk.reasoning_content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Handle regular content - end reasoning phase when we get content
|
||||
if (chunk.content) {
|
||||
if (isReasoning) {
|
||||
// Calculate reasoning time
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
isReasoning = false;
|
||||
|
||||
// Update message with reasoning time
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Update the last message (assistant) with new content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + chunk.content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
// User cancelled, keep partial response
|
||||
// If we were still reasoning, record the time
|
||||
if (isReasoning && reasoningStartTime > 0) {
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Show error in the assistant message
|
||||
const errorMessage = error instanceof Error ? error.message : "An error occurred";
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + `\n\n**Error:** ${errorMessage}` }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
isStreaming = false;
|
||||
isReasoning = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
async function editMessage(idx: number, newContent: string) {
|
||||
if (isStreaming || !$selectedModelStore) return;
|
||||
|
||||
// Update the user message at the specified index
|
||||
messages = messages.map((msg, i) =>
|
||||
i === idx ? { ...msg, content: newContent } : msg
|
||||
);
|
||||
|
||||
// Trigger a new chat request with the updated messages
|
||||
await regenerateFromIndex(idx);
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
sendMessage();
|
||||
}
|
||||
}
|
||||
|
||||
const ACCEPTED_IMAGE_FORMATS = ["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||
const MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB
|
||||
const MAX_IMAGES_PER_MESSAGE = 5;
|
||||
|
||||
function validateImageFile(file: File): string | null {
|
||||
if (!ACCEPTED_IMAGE_FORMATS.includes(file.type)) {
|
||||
return `Invalid file type: ${file.type}. Accepted formats: JPG, PNG, GIF, WEBP`;
|
||||
}
|
||||
if (file.size > MAX_IMAGE_SIZE) {
|
||||
return `File too large: ${(file.size / 1024 / 1024).toFixed(1)}MB. Maximum size: 20MB`;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function fileToDataUrl(file: File): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.onerror = () => reject(new Error("Failed to read file"));
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function processImageFiles(files: File[]): Promise<void> {
|
||||
imageError = null;
|
||||
|
||||
if (attachedImages.length + files.length > MAX_IMAGES_PER_MESSAGE) {
|
||||
imageError = `Maximum ${MAX_IMAGES_PER_MESSAGE} images per message`;
|
||||
return;
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
const error = validateImageFile(file);
|
||||
if (error) {
|
||||
imageError = error;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const dataUrls = await Promise.all(files.map(fileToDataUrl));
|
||||
attachedImages = [...attachedImages, ...dataUrls];
|
||||
} catch (error) {
|
||||
imageError = error instanceof Error ? error.message : "Failed to process images";
|
||||
}
|
||||
}
|
||||
|
||||
function handleImageSelect(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
if (input.files && input.files.length > 0) {
|
||||
processImageFiles(Array.from(input.files));
|
||||
}
|
||||
// Reset the input so the same file can be selected again
|
||||
input.value = "";
|
||||
}
|
||||
|
||||
function removeImage(idx: number) {
|
||||
attachedImages = attachedImages.filter((_, i) => i !== idx);
|
||||
imageError = null;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and controls -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a model..." disabled={isStreaming} />
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => (showSettings = !showSettings)}
|
||||
title="Settings"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M8.34 1.804A1 1 0 0 1 9.32 1h1.36a1 1 0 0 1 .98.804l.295 1.473c.497.144.971.342 1.416.587l1.25-.834a1 1 0 0 1 1.262.125l.962.962a1 1 0 0 1 .125 1.262l-.834 1.25c.245.445.443.919.587 1.416l1.473.295a1 1 0 0 1 .804.98v1.36a1 1 0 0 1-.804.98l-1.473.295a6.95 6.95 0 0 1-.587 1.416l.834 1.25a1 1 0 0 1-.125 1.262l-.962.962a1 1 0 0 1-1.262.125l-1.25-.834a6.953 6.953 0 0 1-1.416.587l-.295 1.473a1 1 0 0 1-.98.804H9.32a1 1 0 0 1-.98-.804l-.295-1.473a6.957 6.957 0 0 1-1.416-.587l-1.25.834a1 1 0 0 1-1.262-.125l-.962-.962a1 1 0 0 1-.125-1.262l.834-1.25a6.957 6.957 0 0 1-.587-1.416l-1.473-.295A1 1 0 0 1 1 10.68V9.32a1 1 0 0 1 .804-.98l1.473-.295c.144-.497.342-.971.587-1.416l-.834-1.25a1 1 0 0 1 .125-1.262l.962-.962A1 1 0 0 1 5.38 3.03l1.25.834a6.957 6.957 0 0 1 1.416-.587l.294-1.473ZM13 10a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn" onclick={newChat} disabled={messages.length === 0 && !isStreaming}>
|
||||
New Chat
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Settings panel -->
|
||||
{#if showSettings}
|
||||
<div class="shrink-0 mb-4 p-4 bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
<div class="mb-4">
|
||||
<label class="block text-sm font-medium mb-1" for="system-prompt">System Prompt</label>
|
||||
<textarea
|
||||
id="system-prompt"
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder="You are a helpful assistant..."
|
||||
rows="3"
|
||||
bind:value={$systemPromptStore}
|
||||
disabled={isStreaming}
|
||||
></textarea>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm font-medium mb-1" for="temperature">
|
||||
Temperature: {$temperatureStore.toFixed(2)}
|
||||
</label>
|
||||
<input
|
||||
id="temperature"
|
||||
type="range"
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.05"
|
||||
class="w-full"
|
||||
bind:value={$temperatureStore}
|
||||
disabled={isStreaming}
|
||||
/>
|
||||
<div class="flex justify-between text-xs text-txtsecondary mt-1">
|
||||
<span>Precise (0)</span>
|
||||
<span>Creative (2)</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to start chatting.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Messages area -->
|
||||
<div
|
||||
class="flex-1 overflow-y-auto mb-4 px-2"
|
||||
bind:this={messagesContainer}
|
||||
onscroll={handleMessagesScroll}
|
||||
>
|
||||
{#if messages.length === 0}
|
||||
<div class="h-full flex items-center justify-center text-txtsecondary">
|
||||
<p>Start a conversation by typing a message below.</p>
|
||||
</div>
|
||||
{:else}
|
||||
{#each messages as message, idx (idx)}
|
||||
<ChatMessageComponent
|
||||
role={message.role}
|
||||
content={message.content}
|
||||
reasoning_content={message.reasoning_content}
|
||||
reasoningTimeMs={message.reasoningTimeMs}
|
||||
isStreaming={isStreaming && idx === messages.length - 1 && message.role === "assistant"}
|
||||
isReasoning={isReasoning && idx === messages.length - 1 && message.role === "assistant"}
|
||||
onEdit={message.role === "user" ? (newContent) => editMessage(idx, newContent) : undefined}
|
||||
onRegenerate={message.role === "assistant" && idx > 0 && messages[idx - 1].role === "user"
|
||||
? () => regenerateFromIndex(idx - 1)
|
||||
: undefined}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Input area -->
|
||||
<div class="shrink-0">
|
||||
<!-- Image preview strip -->
|
||||
{#if attachedImages.length > 0}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each attachedImages as imageUrl, idx (idx)}
|
||||
<div class="relative group">
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Attached image {idx + 1}"
|
||||
class="w-20 h-20 object-cover rounded border border-gray-200 dark:border-white/10"
|
||||
/>
|
||||
<button
|
||||
class="absolute -top-2 -right-2 bg-red-500 text-white rounded-full w-6 h-6 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onclick={() => removeImage(idx)}
|
||||
title="Remove image"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Error message -->
|
||||
{#if imageError}
|
||||
<div class="mb-2 p-2 bg-red-100 dark:bg-red-900/20 text-red-700 dark:text-red-400 rounded text-sm">
|
||||
{imageError}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="flex gap-2">
|
||||
<!-- Hidden file input -->
|
||||
<input
|
||||
type="file"
|
||||
accept=".jpg,.jpeg,.png,.gif,.webp"
|
||||
multiple
|
||||
class="hidden"
|
||||
bind:this={fileInput}
|
||||
onchange={handleImageSelect}
|
||||
/>
|
||||
|
||||
<ExpandableTextarea
|
||||
bind:value={userInput}
|
||||
placeholder="Type a message..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-col gap-2">
|
||||
{#if isStreaming}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelStreaming}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
title="Attach image"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M1 5.25A2.25 2.25 0 0 1 3.25 3h13.5A2.25 2.25 0 0 1 19 5.25v9.5A2.25 2.25 0 0 1 16.75 17H3.25A2.25 2.25 0 0 1 1 14.75v-9.5Zm1.5 5.81v3.69c0 .414.336.75.75.75h13.5a.75.75 0 0 0 .75-.75v-2.69l-2.22-2.219a.75.75 0 0 0-1.06 0l-1.91 1.909.47.47a.75.75 0 1 1-1.06 1.06L6.53 8.091a.75.75 0 0 0-1.06 0l-2.97 2.97ZM12 7a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={sendMessage}
|
||||
disabled={(!userInput.trim() && attachedImages.length === 0) || !$selectedModelStore}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,467 @@
|
||||
<script lang="ts">
|
||||
import { renderMarkdown, escapeHtml, renderStreamingMarkdown, createStreamingCache } from "../../lib/markdown";
|
||||
import type { RenderedBlock } from "../../lib/markdown";
|
||||
import { Copy, Check, Pencil, X, Save, RefreshCw, ChevronDown, ChevronRight, Brain, Code } from "lucide-svelte";
|
||||
import { getTextContent, getImageUrls } from "../../lib/types";
|
||||
import type { ContentPart } from "../../lib/types";
|
||||
|
||||
interface Props {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string | ContentPart[];
|
||||
reasoning_content?: string;
|
||||
reasoningTimeMs?: number;
|
||||
isStreaming?: boolean;
|
||||
isReasoning?: boolean;
|
||||
onEdit?: (newContent: string) => void;
|
||||
onRegenerate?: () => void;
|
||||
}
|
||||
|
||||
let { role, content, reasoning_content = "", reasoningTimeMs = 0, isStreaming = false, isReasoning = false, onEdit, onRegenerate }: Props = $props();
|
||||
|
||||
let textContent = $derived(getTextContent(content));
|
||||
let imageUrls = $derived(getImageUrls(content));
|
||||
let hasImages = $derived(imageUrls.length > 0);
|
||||
let canEdit = $derived(onEdit !== undefined && !hasImages);
|
||||
|
||||
let streamingCache = createStreamingCache();
|
||||
let renderedParts = $derived.by(() => {
|
||||
if (role !== "assistant") {
|
||||
return { blocks: [{ id: -1, html: escapeHtml(textContent).replace(/\n/g, '<br>') }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
if (!isStreaming) {
|
||||
streamingCache = createStreamingCache();
|
||||
return { blocks: [{ id: -1, html: renderMarkdown(textContent) }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
return renderStreamingMarkdown(textContent, streamingCache);
|
||||
});
|
||||
let copied = $state(false);
|
||||
let showRaw = $state(false);
|
||||
let isEditing = $state(false);
|
||||
let editContent = $state("");
|
||||
let showReasoning = $state(false);
|
||||
let modalImageUrl = $state<string | null>(null);
|
||||
|
||||
function formatDuration(ms: number): string {
|
||||
if (ms < 1000) {
|
||||
return `${ms.toFixed(0)}ms`;
|
||||
}
|
||||
return `${(ms / 1000).toFixed(1)}s`;
|
||||
}
|
||||
|
||||
async function copyToClipboard() {
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(textContent);
|
||||
} else {
|
||||
// Fallback for non-secure contexts (HTTP)
|
||||
const textarea = document.createElement("textarea");
|
||||
textarea.value = textContent;
|
||||
textarea.style.position = "fixed";
|
||||
textarea.style.left = "-9999px";
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand("copy");
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
copied = true;
|
||||
setTimeout(() => (copied = false), 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
}
|
||||
|
||||
function startEdit() {
|
||||
editContent = textContent;
|
||||
isEditing = true;
|
||||
}
|
||||
|
||||
function cancelEdit() {
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function saveEdit() {
|
||||
if (onEdit && editContent.trim() !== textContent) {
|
||||
onEdit(editContent.trim());
|
||||
}
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function openModal(imageUrl: string) {
|
||||
modalImageUrl = imageUrl;
|
||||
document.body.style.overflow = "hidden";
|
||||
}
|
||||
|
||||
function closeModal(event?: MouseEvent) {
|
||||
// Only close if clicking the background, not the image
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
modalImageUrl = null;
|
||||
document.body.style.overflow = "";
|
||||
}
|
||||
|
||||
function handleModalKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeModal();
|
||||
}
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
saveEdit();
|
||||
} else if (event.key === "Escape") {
|
||||
cancelEdit();
|
||||
}
|
||||
}
|
||||
|
||||
const COPY_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>`;
|
||||
const CHECK_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M20 6 9 17l-5-5"/></svg>`;
|
||||
|
||||
function codeBlockCopy(node: HTMLElement) {
|
||||
function attachButtons() {
|
||||
node.querySelectorAll<HTMLPreElement>('pre:not([data-copy-btn])').forEach(pre => {
|
||||
pre.setAttribute('data-copy-btn', 'true');
|
||||
const btn = document.createElement('button');
|
||||
btn.className = 'code-copy-btn';
|
||||
btn.title = 'Copy code';
|
||||
btn.innerHTML = COPY_SVG;
|
||||
btn.addEventListener('click', async () => {
|
||||
const text = pre.querySelector('code')?.textContent ?? pre.textContent ?? '';
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else {
|
||||
const ta = document.createElement('textarea');
|
||||
ta.value = text;
|
||||
ta.style.cssText = 'position:fixed;left:-9999px';
|
||||
document.body.appendChild(ta);
|
||||
ta.select();
|
||||
document.execCommand('copy');
|
||||
document.body.removeChild(ta);
|
||||
}
|
||||
btn.innerHTML = CHECK_SVG;
|
||||
btn.classList.add('copied');
|
||||
setTimeout(() => { btn.innerHTML = COPY_SVG; btn.classList.remove('copied'); }, 2000);
|
||||
} catch (e) {
|
||||
console.error('copy failed', e);
|
||||
}
|
||||
});
|
||||
pre.appendChild(btn);
|
||||
});
|
||||
}
|
||||
attachButtons();
|
||||
const mo = new MutationObserver(attachButtons);
|
||||
mo.observe(node, { childList: true, subtree: true });
|
||||
return { destroy: () => mo.disconnect() };
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex {role === 'user' ? 'justify-end' : 'justify-start'} mb-4">
|
||||
<div
|
||||
class="relative group rounded-lg px-4 py-2 {role === 'user'
|
||||
? 'max-w-[85%] bg-primary text-btn-primary-text'
|
||||
: 'w-full sm:w-4/5 bg-surface border border-gray-200 dark:border-white/10'}"
|
||||
>
|
||||
{#if role === "assistant"}
|
||||
{#if reasoning_content || isReasoning}
|
||||
<div class="mb-3 border border-gray-200 dark:border-white/10 rounded overflow-hidden">
|
||||
<button
|
||||
class="w-full flex items-center gap-2 px-3 py-2 bg-gray-50 dark:bg-white/5 hover:bg-gray-100 dark:hover:bg-white/10 transition-colors text-sm"
|
||||
onclick={() => showReasoning = !showReasoning}
|
||||
>
|
||||
{#if showReasoning}
|
||||
<ChevronDown class="w-4 h-4" />
|
||||
{:else}
|
||||
<ChevronRight class="w-4 h-4" />
|
||||
{/if}
|
||||
<Brain class="w-4 h-4" />
|
||||
<span class="font-medium">Reasoning</span>
|
||||
<span class="text-txtsecondary ml-2">
|
||||
({reasoning_content.length} chars{#if !isReasoning && reasoningTimeMs > 0}, {formatDuration(reasoningTimeMs)}{/if})
|
||||
</span>
|
||||
{#if isReasoning}
|
||||
<span class="ml-auto flex items-center gap-1 text-txtsecondary">
|
||||
<span class="w-1.5 h-1.5 bg-primary rounded-full animate-pulse"></span>
|
||||
reasoning...
|
||||
</span>
|
||||
{/if}
|
||||
</button>
|
||||
{#if showReasoning}
|
||||
<div class="px-3 py-2 bg-gray-50/50 dark:bg-white/[0.02] text-sm text-txtsecondary whitespace-pre-wrap font-mono">
|
||||
{reasoning_content}{#if isReasoning}<span class="inline-block w-1.5 h-4 bg-current animate-pulse ml-0.5"></span>{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if hasImages}
|
||||
<div class="mb-3 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-gray-200 dark:border-white/10 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-h-64 rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if showRaw}
|
||||
<div class="whitespace-pre-wrap font-mono text-sm">{textContent}</div>
|
||||
{:else}
|
||||
<div class="prose prose-sm dark:prose-invert max-w-none" use:codeBlockCopy>
|
||||
{#each renderedParts.blocks as block (block.id)}
|
||||
{@html block.html}
|
||||
{/each}
|
||||
{@html renderedParts.pendingHtml}
|
||||
{#if isStreaming && !isReasoning}
|
||||
<span class="inline-block w-2 h-4 bg-current animate-pulse ml-0.5"></span>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if !isStreaming}
|
||||
<div class="flex gap-1 mt-2 pt-1 border-t border-gray-200 dark:border-white/10">
|
||||
{#if onRegenerate}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={onRegenerate}
|
||||
title="Regenerate response"
|
||||
>
|
||||
<RefreshCw class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? "Copied!" : "Copy to clipboard"}
|
||||
>
|
||||
{#if copied}
|
||||
<Check class="w-4 h-4 text-green-500" />
|
||||
{:else}
|
||||
<Copy class="w-4 h-4" />
|
||||
{/if}
|
||||
</button>
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 {showRaw ? 'text-primary' : 'text-txtsecondary'}"
|
||||
onclick={() => showRaw = !showRaw}
|
||||
title={showRaw ? "Show rendered" : "Show raw"}
|
||||
>
|
||||
<Code class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
{#if isEditing}
|
||||
<div class="flex flex-col gap-2 min-w-[300px]">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface text-txtmain focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
rows="3"
|
||||
bind:value={editContent}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
<div class="flex justify-end gap-2">
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={cancelEdit}
|
||||
title="Cancel"
|
||||
>
|
||||
<X class="w-4 h-4" />
|
||||
</button>
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={saveEdit}
|
||||
title="Save"
|
||||
>
|
||||
<Save class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
{#if hasImages}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-white/20 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-w-[200px] rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="whitespace-pre-wrap pr-8">{textContent}</div>
|
||||
{#if canEdit}
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-0 group-hover:opacity-100 transition-opacity bg-white/20 hover:bg-white/30 shadow-sm"
|
||||
onclick={startEdit}
|
||||
title="Edit message"
|
||||
>
|
||||
<Pencil class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Full-size image modal -->
|
||||
{#if modalImageUrl}
|
||||
<div
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/80 p-4"
|
||||
onclick={(e) => closeModal(e)}
|
||||
onkeydown={handleModalKeyDown}
|
||||
role="button"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 p-2 rounded-lg bg-white/10 hover:bg-white/20 text-white transition-colors"
|
||||
onclick={() => closeModal()}
|
||||
title="Close"
|
||||
>
|
||||
<X class="w-6 h-6" />
|
||||
</button>
|
||||
<img
|
||||
src={modalImageUrl}
|
||||
alt=""
|
||||
class="max-w-full max-h-full rounded pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.prose :global(pre) {
|
||||
position: relative;
|
||||
background-color: var(--color-surface);
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
border-radius: 0.375rem;
|
||||
padding: 0.75rem;
|
||||
padding-right: 2.5rem;
|
||||
overflow-x: auto;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn) {
|
||||
position: absolute;
|
||||
top: 0.375rem;
|
||||
right: 0.375rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface);
|
||||
color: var(--color-txtsecondary);
|
||||
cursor: pointer;
|
||||
transition: background-color 0.15s;
|
||||
line-height: 0;
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn:hover) {
|
||||
background: var(--color-secondary);
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn.copied) {
|
||||
color: var(--color-success);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.prose :global(code) {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
.prose :global(pre code) {
|
||||
background: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.prose :global(code:not(pre code)) {
|
||||
background-color: var(--color-surface);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
}
|
||||
|
||||
.prose :global(p) {
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(p:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.prose :global(ul),
|
||||
.prose :global(ol) {
|
||||
margin: 0.5rem 0;
|
||||
padding-left: 1.5rem;
|
||||
}
|
||||
|
||||
.prose :global(li) {
|
||||
margin: 0.25rem 0;
|
||||
}
|
||||
|
||||
.prose :global(h1),
|
||||
.prose :global(h2),
|
||||
.prose :global(h3),
|
||||
.prose :global(h4) {
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.prose :global(h1:first-child),
|
||||
.prose :global(h2:first-child),
|
||||
.prose :global(h3:first-child),
|
||||
.prose :global(h4:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(blockquote) {
|
||||
border-left: 3px solid var(--color-primary);
|
||||
padding-left: 1rem;
|
||||
margin: 0.5rem 0;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.prose :global(a) {
|
||||
color: var(--color-primary);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.prose :global(table) {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(th),
|
||||
.prose :global(td) {
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.prose :global(th) {
|
||||
background-color: var(--color-surface);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* Highlight.js theme overrides for dark mode */
|
||||
:global(.dark) .prose :global(.hljs) {
|
||||
background: transparent;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,121 @@
|
||||
<script lang="ts">
|
||||
import { untrack } from "svelte";
|
||||
import { Maximize2, X } from "lucide-svelte";
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
rows?: number;
|
||||
disabled?: boolean;
|
||||
onkeydown?: (event: KeyboardEvent) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
value = $bindable(),
|
||||
placeholder = "",
|
||||
rows = 3,
|
||||
disabled = false,
|
||||
onkeydown,
|
||||
}: Props = $props();
|
||||
|
||||
let isExpanded = $state(false);
|
||||
let expandedValue = $state("");
|
||||
let expandedTextarea: HTMLTextAreaElement | undefined = $state();
|
||||
|
||||
function openExpanded() {
|
||||
expandedValue = value;
|
||||
isExpanded = true;
|
||||
}
|
||||
|
||||
function closeExpanded() {
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function saveExpanded() {
|
||||
value = expandedValue;
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeExpanded();
|
||||
}
|
||||
}
|
||||
|
||||
// Focus the textarea when expanded view opens
|
||||
$effect(() => {
|
||||
if (isExpanded && expandedTextarea) {
|
||||
expandedTextarea.focus();
|
||||
const len = untrack(() => expandedValue.length);
|
||||
expandedTextarea.setSelectionRange(len, len);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex-1 relative group flex items-stretch min-h-0">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 pr-10 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-inset focus:ring-primary resize-none"
|
||||
{placeholder}
|
||||
{rows}
|
||||
bind:value
|
||||
{onkeydown}
|
||||
{disabled}
|
||||
></textarea>
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-60 md:opacity-0 group-hover:opacity-100 transition-opacity bg-surface/90 hover:bg-surface border border-gray-200 dark:border-white/10 shadow-sm"
|
||||
onclick={openExpanded}
|
||||
title="Expand to edit"
|
||||
type="button"
|
||||
{disabled}
|
||||
>
|
||||
<Maximize2 class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="fixed inset-0 z-50 flex items-center justify-center bg-black/50 p-4">
|
||||
<div class="w-full max-w-4xl h-[80vh] flex flex-col bg-surface rounded-lg shadow-xl border border-gray-200 dark:border-white/10">
|
||||
<!-- Header -->
|
||||
<div class="flex justify-between items-center p-4 border-b border-gray-200 dark:border-white/10">
|
||||
<h3 class="font-medium">Edit Text</h3>
|
||||
<button
|
||||
class="p-1.5 rounded-lg hover:bg-gray-100 dark:hover:bg-white/10"
|
||||
onclick={closeExpanded}
|
||||
title="Close"
|
||||
type="button"
|
||||
>
|
||||
<X class="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Textarea -->
|
||||
<div class="flex-1 p-4">
|
||||
<textarea
|
||||
bind:this={expandedTextarea}
|
||||
class="w-full h-full px-4 py-3 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder={placeholder}
|
||||
bind:value={expandedValue}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<!-- Footer -->
|
||||
<div class="flex justify-end gap-2 p-4 border-t border-gray-200 dark:border-white/10">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={closeExpanded}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={saveExpanded}
|
||||
type="button"
|
||||
>
|
||||
Done
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,521 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { generateImage } from "../../lib/imageApi";
|
||||
import { generateSdImage, fetchSdLoras } from "../../lib/sdApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
import type { ImageApiMode, SdApiLora, SdApiLoraRef } from "../../lib/types";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-image-model", "");
|
||||
const selectedSizeStore = persistentStore<string>("playground-image-size", "1024x1024");
|
||||
const apiModeStore = persistentStore<ImageApiMode>("playground-image-api-mode", "openai");
|
||||
|
||||
// SDAPI persistent settings
|
||||
const sdNegativePromptStore = persistentStore<string>("playground-sdapi-negative-prompt", "");
|
||||
const sdStepsStore = persistentStore<number>("playground-sdapi-steps", 20);
|
||||
const sdCfgScaleStore = persistentStore<number>("playground-sdapi-cfg-scale", 7);
|
||||
const sdSeedStore = persistentStore<number>("playground-sdapi-seed", -1);
|
||||
const sdSamplerStore = persistentStore<string>("playground-sdapi-sampler", "");
|
||||
const sdSchedulerStore = persistentStore<string>("playground-sdapi-scheduler", "");
|
||||
const sdBatchSizeStore = persistentStore<number>("playground-sdapi-batch-size", 1);
|
||||
|
||||
let prompt = $state("");
|
||||
let isGenerating = $state(false);
|
||||
let generatedImages = $state<string[]>([]);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let showFullscreen = $state(false);
|
||||
let fullscreenIndex = $state(0);
|
||||
let showSettings = $state(false);
|
||||
|
||||
// SDAPI lora state
|
||||
let availableLoras = $state<SdApiLora[]>([]);
|
||||
let selectedLoras = $state<SdApiLoraRef[]>([]);
|
||||
let isLoadingLoras = $state(false);
|
||||
let lorasLoaded = $state(false);
|
||||
let loraError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
let isSdapi = $derived($apiModeStore === "sdapi");
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.imageGenerating.set(isGenerating);
|
||||
});
|
||||
|
||||
async function loadLoras() {
|
||||
if (!$selectedModelStore || isLoadingLoras) return;
|
||||
isLoadingLoras = true;
|
||||
loraError = null;
|
||||
try {
|
||||
const loras = await fetchSdLoras($selectedModelStore);
|
||||
availableLoras = loras;
|
||||
lorasLoaded = true;
|
||||
} catch (err) {
|
||||
availableLoras = [];
|
||||
loraError = err instanceof Error ? err.message : "Failed to load LoRAs";
|
||||
lorasLoaded = false;
|
||||
} finally {
|
||||
isLoadingLoras = false;
|
||||
}
|
||||
}
|
||||
|
||||
function addLora(event: Event) {
|
||||
const select = event.target as HTMLSelectElement;
|
||||
const path = select.value;
|
||||
if (!path) return;
|
||||
|
||||
const lora = availableLoras.find((l) => l.path === path);
|
||||
if (lora && !selectedLoras.some((l) => l.path === path)) {
|
||||
selectedLoras = [...selectedLoras, { path: lora.path, multiplier: 1.0 }];
|
||||
}
|
||||
select.value = "";
|
||||
}
|
||||
|
||||
function removeLora(path: string) {
|
||||
selectedLoras = selectedLoras.filter((l) => l.path !== path);
|
||||
}
|
||||
|
||||
function updateLoraMultiplier(path: string, multiplier: number) {
|
||||
selectedLoras = selectedLoras.map((l) =>
|
||||
l.path === path ? { ...l, multiplier } : l
|
||||
);
|
||||
}
|
||||
|
||||
function getLoraName(path: string): string {
|
||||
return availableLoras.find((l) => l.path === path)?.name ?? path;
|
||||
}
|
||||
|
||||
async function generate() {
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (!trimmedPrompt || !$selectedModelStore || isGenerating) return;
|
||||
|
||||
isGenerating = true;
|
||||
error = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
if (isSdapi) {
|
||||
const [w, h] = $selectedSizeStore.split("x").map(Number);
|
||||
const request = {
|
||||
model: $selectedModelStore,
|
||||
prompt: trimmedPrompt,
|
||||
negative_prompt: $sdNegativePromptStore || undefined,
|
||||
width: w,
|
||||
height: h,
|
||||
steps: $sdStepsStore,
|
||||
cfg_scale: $sdCfgScaleStore,
|
||||
seed: $sdSeedStore,
|
||||
batch_size: $sdBatchSizeStore,
|
||||
sampler_name: $sdSamplerStore || undefined,
|
||||
scheduler: $sdSchedulerStore || undefined,
|
||||
lora: selectedLoras.length > 0 ? selectedLoras : undefined,
|
||||
};
|
||||
|
||||
const response = await generateSdImage(request, abortController.signal);
|
||||
if (response.images && response.images.length > 0) {
|
||||
generatedImages = response.images.map(
|
||||
(img) => `data:image/png;base64,${img}`
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const response = await generateImage(
|
||||
$selectedModelStore,
|
||||
trimmedPrompt,
|
||||
$selectedSizeStore,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
if (response.data && response.data.length > 0) {
|
||||
const imageData = response.data[0];
|
||||
if (imageData.b64_json) {
|
||||
generatedImages = [`data:image/png;base64,${imageData.b64_json}`];
|
||||
} else if (imageData.url) {
|
||||
generatedImages = [imageData.url];
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelGeneration() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearImage() {
|
||||
generatedImages = [];
|
||||
error = null;
|
||||
prompt = "";
|
||||
}
|
||||
|
||||
function downloadImage(index: number = 0) {
|
||||
const img = generatedImages[index];
|
||||
if (!img) return;
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = img;
|
||||
link.download = `generated-image-${Date.now()}-${index}.png`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
|
||||
function openFullscreen(index: number = 0) {
|
||||
fullscreenIndex = index;
|
||||
showFullscreen = true;
|
||||
}
|
||||
|
||||
function closeFullscreen(event?: MouseEvent) {
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
showFullscreen = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
generate();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and mode toggle -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
|
||||
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$apiModeStore}
|
||||
disabled={isGenerating}
|
||||
>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="sdapi">SDAPI</option>
|
||||
</select>
|
||||
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$selectedSizeStore}
|
||||
disabled={isGenerating}
|
||||
>
|
||||
<optgroup label="Square">
|
||||
<option value="512x512">512x512</option>
|
||||
<option value="1024x1024">1024x1024</option>
|
||||
</optgroup>
|
||||
<optgroup label="Landscape">
|
||||
<option value="1024x768">1024x768 (4:3)</option>
|
||||
<option value="1280x720">1280x720 (16:9)</option>
|
||||
<option value="1792x1024">1792x1024 (SDXL)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Portrait">
|
||||
<option value="768x1024">768x1024 (3:4)</option>
|
||||
<option value="720x1280">720x1280 (9:16)</option>
|
||||
<option value="1024x1792">1024x1792 (SDXL)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
|
||||
{#if isSdapi}
|
||||
<button
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors"
|
||||
onclick={() => showSettings = !showSettings}
|
||||
>
|
||||
{showSettings ? "Hide Settings" : "Settings"}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- SDAPI Settings Panel -->
|
||||
{#if isSdapi && showSettings}
|
||||
<div class="shrink-0 mb-4 p-4 rounded border border-gray-200 dark:border-white/10 bg-surface">
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-3 mb-3">
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Steps</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdStepsStore}
|
||||
min="1"
|
||||
max="150"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">CFG Scale</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdCfgScaleStore}
|
||||
min="1"
|
||||
max="30"
|
||||
step="0.5"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Seed (-1 = random)</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSeedStore}
|
||||
min="-1"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Batch Size</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdBatchSizeStore}
|
||||
min="1"
|
||||
max="8"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Sampler</span>
|
||||
<select
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSamplerStore}
|
||||
>
|
||||
<option value="">Default</option>
|
||||
<option value="euler_a">euler_a</option>
|
||||
<option value="euler">euler</option>
|
||||
<option value="heun">heun</option>
|
||||
<option value="dpm2">dpm2</option>
|
||||
<option value="dpmpp2s_a">dpmpp2s_a</option>
|
||||
<option value="dpmpp2m">dpmpp2m</option>
|
||||
<option value="dpmpp2mv2">dpmpp2mv2</option>
|
||||
<option value="ipndm">ipndm</option>
|
||||
<option value="ipndm_v">ipndm_v</option>
|
||||
<option value="lcm">lcm</option>
|
||||
<option value="ddim_trailing">ddim_trailing</option>
|
||||
<option value="tcd">tcd</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Scheduler</span>
|
||||
<select
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSchedulerStore}
|
||||
>
|
||||
<option value="">Auto for model</option>
|
||||
<option value="discrete">discrete</option>
|
||||
<option value="karras">karras</option>
|
||||
<option value="exponential">exponential</option>
|
||||
<option value="ays">ays</option>
|
||||
<option value="gits">gits</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label class="flex flex-col gap-1 mb-3">
|
||||
<span class="text-xs text-txtsecondary">Negative Prompt</span>
|
||||
<textarea
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-y text-sm"
|
||||
bind:value={$sdNegativePromptStore}
|
||||
rows="2"
|
||||
placeholder="Elements to avoid..."
|
||||
></textarea>
|
||||
</label>
|
||||
|
||||
<!-- LoRA Selection -->
|
||||
<div>
|
||||
<span class="text-xs text-txtsecondary block mb-1">LoRAs</span>
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<button
|
||||
class="px-3 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors disabled:opacity-50"
|
||||
onclick={loadLoras}
|
||||
disabled={!$selectedModelStore || isLoadingLoras}
|
||||
>
|
||||
{isLoadingLoras ? "Loading..." : lorasLoaded ? "Reload LoRAs" : "Load LoRAs"}
|
||||
</button>
|
||||
{#if lorasLoaded && availableLoras.length > 0}
|
||||
<select
|
||||
class="flex-1 px-2 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
onchange={addLora}
|
||||
>
|
||||
<option value="">Add a LoRA...</option>
|
||||
{#each availableLoras.filter((l) => !selectedLoras.some((s) => s.path === l.path)) as lora}
|
||||
<option value={lora.path}>{lora.name}</option>
|
||||
{/each}
|
||||
</select>
|
||||
{/if}
|
||||
</div>
|
||||
{#if loraError}
|
||||
<p class="text-xs text-red-500 mb-1">{loraError}</p>
|
||||
{/if}
|
||||
{#if lorasLoaded && availableLoras.length === 0}
|
||||
<p class="text-xs text-txtsecondary">No LoRAs available</p>
|
||||
{/if}
|
||||
{#if selectedLoras.length > 0}
|
||||
<div class="flex flex-col gap-1.5">
|
||||
{#each selectedLoras as lora}
|
||||
<div class="flex items-center gap-2 text-sm">
|
||||
<span class="flex-1 truncate">{getLoraName(lora.path)}</span>
|
||||
<input
|
||||
type="number"
|
||||
class="w-20 px-1.5 py-1 text-xs rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-1 focus:ring-primary"
|
||||
value={lora.multiplier}
|
||||
oninput={(e) => updateLoraMultiplier(lora.path, parseFloat((e.target as HTMLInputElement).value) || 1)}
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.1"
|
||||
/>
|
||||
<button
|
||||
class="px-1.5 py-0.5 text-xs rounded border border-gray-200 dark:border-white/10 hover:bg-red-500 hover:text-white hover:border-red-500 transition-colors"
|
||||
onclick={() => removeLora(lora.path)}
|
||||
aria-label="Remove LoRA"
|
||||
>
|
||||
x
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to generate images.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Image display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isGenerating}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Generating image...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if generatedImages.length > 1}
|
||||
<!-- Grid for multiple images (batch) -->
|
||||
<div class="grid grid-cols-2 gap-2 p-2 w-full h-full overflow-auto">
|
||||
{#each generatedImages as img, i}
|
||||
<div class="relative flex items-center justify-center">
|
||||
<button
|
||||
class="p-0 border-0 bg-transparent cursor-pointer"
|
||||
onclick={() => openFullscreen(i)}
|
||||
aria-label="View fullscreen"
|
||||
>
|
||||
<img
|
||||
src={img}
|
||||
alt="AI generated content {i + 1}"
|
||||
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
class="absolute bottom-2 right-2 p-1.5 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
|
||||
onclick={(e) => { e.stopPropagation(); downloadImage(i); }}
|
||||
aria-label="Download image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else if generatedImages.length === 1}
|
||||
<div class="relative max-w-full max-h-full flex items-center justify-center">
|
||||
<button
|
||||
class="p-0 border-0 bg-transparent cursor-pointer"
|
||||
onclick={() => openFullscreen(0)}
|
||||
aria-label="View fullscreen"
|
||||
>
|
||||
<img
|
||||
src={generatedImages[0]}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
class="absolute bottom-2 right-2 p-2 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
|
||||
onclick={(e) => { e.stopPropagation(); downloadImage(0); }}
|
||||
aria-label="Download image"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<p>Enter a prompt below to generate an image</p>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Prompt input area -->
|
||||
<div class="shrink-0 flex flex-col md:flex-row gap-2">
|
||||
<ExpandableTextarea
|
||||
bind:value={prompt}
|
||||
placeholder="Describe the image you want to generate..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isGenerating || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-row md:flex-col gap-2">
|
||||
{#if isGenerating}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
|
||||
onclick={generate}
|
||||
disabled={!prompt.trim() || !$selectedModelStore}
|
||||
>
|
||||
Generate
|
||||
</button>
|
||||
<button
|
||||
class="btn flex-1 md:flex-none"
|
||||
onclick={clearImage}
|
||||
disabled={generatedImages.length === 0 && !error && !prompt.trim()}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Fullscreen dialog -->
|
||||
{#if showFullscreen && generatedImages[fullscreenIndex]}
|
||||
<div
|
||||
class="fixed inset-0 bg-black/90 z-50 flex items-center justify-center p-4"
|
||||
onclick={(e) => closeFullscreen(e)}
|
||||
onkeydown={(e) => e.key === 'Escape' && closeFullscreen()}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 text-white hover:text-gray-300 text-2xl w-10 h-10 flex items-center justify-center rounded-full hover:bg-white/10 transition-colors"
|
||||
onclick={() => closeFullscreen()}
|
||||
aria-label="Close fullscreen"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
<img
|
||||
src={generatedImages[fullscreenIndex]}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||