Compare commits
88 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d87f0ce2c5 | |||
| 06bc6a614c | |||
| a37b4866d8 | |||
| 981910d734 | |||
| a185efe37e | |||
| 1dd1aadf93 | |||
| 955900972a | |||
| c2c8cfaf81 | |||
| 1e440770ea | |||
| c794273c83 | |||
| 6574a52cbb | |||
| 8fabc75634 | |||
| e5e7391b6d | |||
| 2c282dccad | |||
| 916d13f5bd | |||
| a3725e7d09 | |||
| 15bd55d3a9 | |||
| c3c258a55d | |||
| 29a38fde0d | |||
| d569681daa | |||
| 24efdb76b1 | |||
| cc77139ff8 | |||
| 390a35bf93 | |||
| 181f71ca11 | |||
| 49546e2cf2 | |||
| 2c078964f4 | |||
| 175bb36fb1 | |||
| aedb640471 | |||
| 2f377f6dc6 | |||
| 64e4c79fc3 | |||
| 19fb5f35e9 | |||
| b45102bde8 | |||
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e |
@@ -4,12 +4,19 @@ early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
|
||||
@@ -4,11 +4,15 @@ on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -39,3 +43,14 @@ jobs:
|
||||
fi
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
|
||||
@@ -10,17 +10,44 @@ on:
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
- 'docker/*.Containerfile'
|
||||
|
||||
# grant permissions on GITHUB_TOKEN to publish packages
|
||||
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -31,7 +58,7 @@ jobs:
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Windows CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
@@ -28,7 +44,7 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
@@ -43,7 +59,7 @@ jobs:
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
|
||||
@@ -3,9 +3,25 @@ name: Linux CI
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
@@ -20,7 +36,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
go-version-file: go.mod
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
@@ -35,7 +51,7 @@ jobs:
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
@@ -51,4 +67,4 @@ jobs:
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
run: make test-all
|
||||
|
||||
@@ -3,13 +3,13 @@ name: goreleaser
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag version to release (e.g. v144)'
|
||||
description: "Tag version to release (e.g. v144)"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
@@ -19,35 +19,30 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
-
|
||||
name: Set up Go
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Set up Node.js
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '23'
|
||||
-
|
||||
name: Install dependencies and build UI
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||
distribution: goreleaser
|
||||
# 'latest', 'nightly', or a semver
|
||||
version: '~> v2'
|
||||
version: "~> v2"
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -76,4 +71,4 @@ jobs:
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
name: UI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ui-svelte
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ui-svelte/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Type check
|
||||
run: npm run check
|
||||
|
||||
- name: Run tests
|
||||
run: npm test
|
||||
@@ -0,0 +1,133 @@
|
||||
name: Build Unified Docker Image
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_cpp_ref:
|
||||
description: "llama.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
whisper_ref:
|
||||
description: "whisper.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
sd_ref:
|
||||
description: "stable-diffusion.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
ik_llama_ref:
|
||||
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
|
||||
required: false
|
||||
default: "main"
|
||||
llama_swap_version:
|
||||
description: "llama-swap version (e.g. v198, latest, main)"
|
||||
required: false
|
||||
default: "main"
|
||||
build_cuda:
|
||||
description: "Build CUDA image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
build_vulkan:
|
||||
description: "Build Vulkan image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
run: |
|
||||
backends=()
|
||||
# schedule uses defaults (build both); workflow_dispatch respects inputs
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
|
||||
backends+=("cuda")
|
||||
fi
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
|
||||
backends+=("vulkan")
|
||||
fi
|
||||
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
|
||||
echo "matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
|
||||
build:
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.matrix != '[]' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
|
||||
variant:
|
||||
- name: root
|
||||
uid: "0"
|
||||
suffix: ""
|
||||
- name: rootless
|
||||
uid: "10001"
|
||||
suffix: "-rootless"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
# On GitHub Actions runners, create a fresh builder.
|
||||
# When running locally under act, skip this and reuse the existing
|
||||
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
|
||||
- name: Set up Docker Buildx
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build unified Docker image (${{ matrix.backend }}, ${{ matrix.variant.name }})
|
||||
env:
|
||||
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
|
||||
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
|
||||
SD_REF: ${{ inputs.sd_ref || 'master' }}
|
||||
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
|
||||
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
|
||||
RUN_UID: ${{ matrix.variant.uid }}
|
||||
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}${{ matrix.variant.suffix }}
|
||||
# When running under act, use the local builder that has warm ccache.
|
||||
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
|
||||
# created by setup-buildx-action above.
|
||||
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
|
||||
run: |
|
||||
chmod +x docker/unified/build-image.sh
|
||||
docker/unified/build-image.sh --${{ matrix.backend }}
|
||||
|
||||
- name: Push to GitHub Container Registry
|
||||
if: ${{ !env.ACT }}
|
||||
run: |
|
||||
TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}${{ matrix.variant.suffix }}"
|
||||
docker push "${TAG}"
|
||||
DATE_TAG=$(date -u +%Y-%m-%d)
|
||||
docker tag "${TAG}" "${TAG}-${DATE_TAG}"
|
||||
docker push "${TAG}-${DATE_TAG}"
|
||||
@@ -0,0 +1,51 @@
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
fixes #123
|
||||
```
|
||||
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
@@ -1,43 +1 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
|
||||
### Implementation of plans
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
|
||||
## General Rules
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
@AGENTS.md
|
||||
|
||||
@@ -36,11 +36,11 @@ test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui && npm install
|
||||
cd ui-svelte && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui && npm run build
|
||||
cd ui-svelte && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac: ui
|
||||
@@ -51,7 +51,7 @@ mac: ui
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
||||
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
||||
|
||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||
|
||||
@@ -13,18 +13,29 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
- `/completion` - for completion endpoint
|
||||
- ✅ SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
|
||||
- `/sdapi/v1/txt2img`
|
||||
- `/sdapi/v1/img2img`
|
||||
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
|
||||
- ✅ llama-swap API
|
||||
- `/ui` - web UI
|
||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
@@ -32,6 +43,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
@@ -40,14 +52,27 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
|
||||
### Web UI
|
||||
|
||||
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
||||
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
||||
|
||||
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
||||
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
||||
|
||||
View detailed token metrics:
|
||||
|
||||
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
|
||||
|
||||
Inspect request and responses:
|
||||
|
||||
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
||||
|
||||
Manually load and unload models:
|
||||
|
||||
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
||||
|
||||
|
||||
The Activity Page shows recent requests:
|
||||
Real time log streaming:
|
||||
|
||||
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
||||
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -61,7 +86,8 @@ llama-swap can be installed in multiple ways
|
||||
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||
|
||||
```shell
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
@@ -71,6 +97,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# configuration hot reload supported with a
|
||||
# directory volume mount
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
-v /path/to/config:/config \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||
```
|
||||
|
||||
<details>
|
||||
@@ -89,6 +123,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
# non-root cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -191,23 +228,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```shell
|
||||
```sh
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
$ curl http://host/logs
|
||||
|
||||
# streams combined logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
curl -Ns http://host/logs/stream
|
||||
|
||||
# just llama-swap's logs
|
||||
curl -Ns 'http://host/logs/stream/proxy'
|
||||
# stream llama-swap's proxy status logs
|
||||
curl -Ns http://host/logs/stream/proxy
|
||||
|
||||
# just upstream's logs
|
||||
curl -Ns 'http://host/logs/stream/upstream'
|
||||
# stream logs from upstream processes that llama-swap loads
|
||||
curl -Ns http://host/logs/stream/upstream
|
||||
|
||||
# stream logs only from a specific model
|
||||
curl -Ns http://host/logs/stream/{model_id}
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
# appending ?no-history will disable sending buffered history first
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||
|
||||
## Current Issues
|
||||
|
||||
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||
4. Linked list structure has poor cache locality and pointer overhead
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### New CircularBuffer Type
|
||||
|
||||
Create a simple circular byte buffer with:
|
||||
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||
- `head` and `size` integers to track write position and data length
|
||||
- No per-write allocations
|
||||
|
||||
### API Requirements
|
||||
|
||||
The new buffer must support:
|
||||
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||
|
||||
### Implementation Details
|
||||
|
||||
```go
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
```
|
||||
|
||||
**Write logic:**
|
||||
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||
- Update `head` and `size` accordingly
|
||||
- Data is copied into the internal buffer (not stored by reference)
|
||||
|
||||
**GetHistory logic:**
|
||||
- Calculate start position: `(head - size + cap) % cap`
|
||||
- If not wrapped: single slice copy
|
||||
- If wrapped: two copies (end of buffer + beginning)
|
||||
- Returns a **new slice** (copy), not a view into internal buffer
|
||||
|
||||
### Immutability Guarantees (must preserve)
|
||||
|
||||
Per existing tests:
|
||||
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||
|
||||
## Files to Modify
|
||||
|
||||
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||
|
||||
## Testing Plan
|
||||
|
||||
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||
|
||||
Add new tests:
|
||||
- Test buffer wrap-around behavior
|
||||
- Test large writes that exceed buffer capacity
|
||||
- Test exact capacity boundary conditions
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||
- [ ] Implement `Write()` method for circular buffer
|
||||
- [ ] Implement `GetHistory()` method for circular buffer
|
||||
- [ ] Update `LogMonitor` struct to use new buffer
|
||||
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||
- [ ] Remove `"container/ring"` import
|
||||
- [ ] Run `make test-dev` to verify existing tests pass
|
||||
- [ ] Add wrap-around test case
|
||||
- [ ] Run `make test-all` for final validation
|
||||
@@ -210,6 +210,11 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||
model := c.Query("model")
|
||||
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
@@ -269,6 +274,43 @@ func main() {
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
// SD API endpoints
|
||||
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
|
||||
@@ -39,6 +39,43 @@
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds. Set to 0 to disable (not recommended)."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 60,
|
||||
"description": "Time to wait for response headers in seconds. Set to 0 to disable (not recommended)."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds. Set to 0 to disable (not recommended)."
|
||||
},
|
||||
"expectContinue": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 1,
|
||||
"description": "Expect-Continue timeout in seconds. Set to 0 to disable (not recommended)."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds. Set to 0 to disable (not recommended)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
@@ -48,6 +85,12 @@
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"globalTTL": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -87,6 +130,12 @@
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"captureBuffer": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
@@ -171,9 +220,9 @@
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
||||
"minimum": -1,
|
||||
"default": -1,
|
||||
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
@@ -188,11 +237,26 @@
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
},
|
||||
"setParamsByID": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"default": {},
|
||||
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Only stripParams is supported."
|
||||
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -214,6 +278,9 @@
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -273,6 +340,109 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 60,
|
||||
"description": "Time to wait for response headers in seconds."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections to this peer."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -34,12 +34,27 @@ logLevel: info
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
@@ -60,6 +75,11 @@ sendLoadingState: true
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# globalTTL: the default TTL in seconds before unloading a model
|
||||
# - optional, default: 0 (never automatically unload)
|
||||
# - must be >= 0
|
||||
globalTTL: 0
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -70,6 +90,9 @@ includeAliasesInList: false
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
@@ -82,6 +105,24 @@ macros:
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -90,7 +131,7 @@ macros:
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
"gpt-oss-120b":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
@@ -107,7 +148,7 @@ models:
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--model path/to/gpt-oss-120B.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
@@ -115,13 +156,13 @@ models:
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
name: "gpt-oss 120B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
description: "A thinking model from OpenAI"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
@@ -136,14 +177,6 @@ models:
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
@@ -152,8 +185,10 @@ models:
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - optional, default: -1 (use global default)
|
||||
# - ttl values must be a value greater than or equal to 0
|
||||
# - a ttl of -1 will use the global TTL value as the default
|
||||
# - a ttl of 0 will mean never unload
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
@@ -161,11 +196,11 @@ models:
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "qwen:qwq"
|
||||
useModelName: "openai/gpt-oss-120B"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
@@ -175,6 +210,43 @@ models:
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - always runs for the model
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||
# - optional, default: empty dictionary
|
||||
# - combine with aliases to create variant behaviour without reloading the model
|
||||
# - parameters are set in the request body JSON
|
||||
# - run after setParams so it will override any settings
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - model aliases will be automatically created for each key
|
||||
setParamsByID:
|
||||
"${MODEL_ID}":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: medium
|
||||
"${MODEL_ID}:high":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: low
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
@@ -212,6 +284,21 @@ models:
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models running on slower hardware that need longer timeouts
|
||||
# - connect: TCP connection timeout in seconds
|
||||
# - responseHeader: time to wait for response headers in seconds
|
||||
# (increasing this helps avoid 502 errors on slow hardware)
|
||||
# - tlsHandshake: TLS handshake timeout in seconds
|
||||
# - idleConn: idle connection timeout in seconds
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -321,3 +408,66 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
@@ -1,55 +1,164 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the server tag
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
SD_TAG=master-${ARCH}
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||
USER_UID=0
|
||||
USER_GID=0
|
||||
USER_HOME=/root
|
||||
|
||||
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||
USER_UID=10001
|
||||
USER_GID=10001
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||
case "$ARCH" in
|
||||
"musa" | "vulkan")
|
||||
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||
--build-arg BASE=${CONTAINER_TAG} \
|
||||
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||
esac
|
||||
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for llama-swap-docker with commit hash pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
|
||||
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
|
||||
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
|
||||
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
|
||||
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
|
||||
#
|
||||
# Features:
|
||||
# - Auto-detects latest commit hashes from git repos
|
||||
# - Builds llama-swap from local source code
|
||||
# - Allows environment variable overrides for reproducible builds
|
||||
# - Cache-friendly: changing commit hash busts cache appropriately
|
||||
# - Supports both CUDA and Vulkan backends (requires explicit flag)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Parse command line arguments
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate backend selection
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Configuration
|
||||
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
|
||||
# User provided a custom tag, use it as-is
|
||||
:
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
DOCKER_IMAGE_TAG="llama-swap:vulkan"
|
||||
else
|
||||
DOCKER_IMAGE_TAG="llama-swap:cuda"
|
||||
fi
|
||||
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
|
||||
|
||||
# Single unified Dockerfile, backend selected via build arg
|
||||
DOCKERFILE="Dockerfile"
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
|
||||
else
|
||||
echo "Building for: CUDA (NVIDIA GPUs)"
|
||||
fi
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
|
||||
# Function to get the latest commit hash from a git repo's default branch
|
||||
get_latest_commit() {
|
||||
local repo_url="$1"
|
||||
local branch="${2:-master}"
|
||||
|
||||
# Try to get the latest commit hash for the specified branch
|
||||
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
# Function to get the default branch name (master or main)
|
||||
get_default_branch() {
|
||||
local repo_url="$1"
|
||||
|
||||
# Check for master first
|
||||
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
|
||||
echo "master"
|
||||
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
|
||||
echo "main"
|
||||
else
|
||||
echo "master" # fallback
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to get the latest release tag from a GitHub repo
|
||||
get_latest_release_tag() {
|
||||
local owner_repo="$1"
|
||||
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap-docker Build Script"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Determine commit hashes / release tags - use env vars or auto-detect
|
||||
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
|
||||
# For cuda builds (or whisper on any backend), use git commit hashes.
|
||||
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
|
||||
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
|
||||
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
|
||||
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
|
||||
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
|
||||
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
|
||||
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
|
||||
SD_HASH="${SD_COMMIT_HASH}"
|
||||
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
|
||||
else
|
||||
SD_BRANCH=$(get_default_branch "${SD_REPO}")
|
||||
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Build the Docker image with commit hashes as build args
|
||||
# Build context is the repository root (..) so the Dockerfile can access Go source
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/${DOCKERFILE}"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
fi
|
||||
|
||||
# Use docker buildx with a custom builder for parallelism control
|
||||
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
|
||||
# We need to use a custom builder with a buildkitd.toml config file
|
||||
BUILDER_NAME="llama-swap-builder"
|
||||
|
||||
# Check if our custom builder exists with the right config, create/update if needed
|
||||
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
|
||||
echo "Creating custom buildx builder with max-parallelism=1..."
|
||||
|
||||
# Create buildkitd.toml config file
|
||||
cat > buildkitd.toml << 'BUILDKIT_EOF'
|
||||
[worker.oci]
|
||||
max-parallelism = 1
|
||||
BUILDKIT_EOF
|
||||
|
||||
# Create the builder with the config
|
||||
docker buildx create --name "$BUILDER_NAME" \
|
||||
--driver docker-container \
|
||||
--buildkitd-config buildkitd.toml \
|
||||
--use
|
||||
else
|
||||
# Switch to our builder
|
||||
docker buildx use "$BUILDER_NAME"
|
||||
fi
|
||||
|
||||
echo "Building with sequential stages (one at a time), each using all CPU cores..."
|
||||
echo "Using builder: $BUILDER_NAME"
|
||||
|
||||
# Use docker buildx build with --load to load the image into Docker
|
||||
# The --builder flag ensures we use our custom builder with max-parallelism=1
|
||||
# Build context is the repository root so we can access Go source files
|
||||
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Verify all expected binaries exist in the image
|
||||
MISSING_BINARIES=()
|
||||
|
||||
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
|
||||
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --vulkan --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tag: ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need to mount render devices:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -15,4 +15,19 @@ models:
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
--port 9999
|
||||
|
||||
z-image:
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||
@@ -0,0 +1,11 @@
|
||||
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||
ARG SD_TAG=master-vulkan
|
||||
ARG BASE=llama-swap:latest
|
||||
|
||||
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||
FROM ${BASE}
|
||||
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
|
||||
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||
@@ -1,8 +1,10 @@
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
# Set default UID/GID arguments
|
||||
ARG UID=10001
|
||||
@@ -27,10 +29,14 @@ RUN chown --recursive $UID:$GID $HOME /app
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Add /app to PATH
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
# Unified multi-stage Dockerfile for AI inference tools
|
||||
# Supports CUDA and Vulkan backends via BACKEND build arg
|
||||
#
|
||||
# Usage:
|
||||
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
|
||||
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
|
||||
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
|
||||
#
|
||||
# Each project has its own install script that handles cloning, building,
|
||||
# and installing binaries. Build stages are independent for cache efficiency.
|
||||
|
||||
ARG BACKEND=cuda
|
||||
|
||||
# ── Builder bases ──────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
|
||||
|
||||
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS builder-base-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget software-properties-common \
|
||||
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ── Select builder base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM builder-base-${BACKEND} AS builder-base
|
||||
|
||||
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
|
||||
|
||||
FROM builder-base AS whisper-build
|
||||
ARG BACKEND=cuda
|
||||
ARG WHISPER_COMMIT_HASH=master
|
||||
COPY install-whisper.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
|
||||
|
||||
# ── Build stable-diffusion.cpp ────────────────────────────────────────
|
||||
|
||||
FROM builder-base AS sd-build
|
||||
ARG BACKEND=cuda
|
||||
ARG SD_COMMIT_HASH=master
|
||||
COPY install-sd.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
|
||||
|
||||
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
|
||||
|
||||
FROM builder-base AS llama-build
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=master
|
||||
COPY install-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
|
||||
|
||||
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
|
||||
#
|
||||
# Two named stages allow ARG BACKEND to select at build time:
|
||||
# - ik-llama-cuda : real build (from builder-base-cuda)
|
||||
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
|
||||
# BuildKit only evaluates the selected branch, so vulkan builds never
|
||||
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
|
||||
|
||||
FROM builder-base-vulkan AS ik-llama-vulkan
|
||||
RUN mkdir -p /install/bin
|
||||
|
||||
FROM builder-base-cuda AS ik-llama-cuda
|
||||
ARG IK_LLAMA_COMMIT_HASH=main
|
||||
COPY install-ik-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
|
||||
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
|
||||
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
|
||||
|
||||
ARG BACKEND=cuda
|
||||
FROM ik-llama-${BACKEND} AS ik-llama-build
|
||||
|
||||
# ── Download llama-swap release binary ────────────────────────────────
|
||||
|
||||
FROM builder-base AS llama-swap-download
|
||||
ARG LS_VERSION=latest
|
||||
COPY install-llama-swap.sh /build/
|
||||
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
|
||||
|
||||
# ── Runtime bases ─────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# CUDA stub drivers for container compatibility
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS runtime-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 libvulkan1 mesa-vulkan-drivers \
|
||||
python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ── Select runtime base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM runtime-${BACKEND} AS runtime
|
||||
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=unknown
|
||||
ARG WHISPER_COMMIT_HASH=unknown
|
||||
ARG SD_COMMIT_HASH=unknown
|
||||
ARG IK_LLAMA_COMMIT_HASH=unknown
|
||||
ARG RUN_UID=0
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3-numpy python3-sentencepiece \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user when RUN_UID != 0
|
||||
RUN if [ "$RUN_UID" != "0" ]; then \
|
||||
groupadd --system --gid $RUN_UID llama-swap && \
|
||||
useradd --system --uid $RUN_UID --gid $RUN_UID \
|
||||
--home /app --shell /sbin/nologin llama-swap; \
|
||||
fi && \
|
||||
mkdir -p /etc/llama-swap/config && \
|
||||
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy whisper.cpp binaries and libraries
|
||||
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
|
||||
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
|
||||
COPY --from=whisper-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy stable-diffusion.cpp binaries and libraries
|
||||
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
|
||||
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
|
||||
COPY --from=sd-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy llama.cpp binaries (statically linked)
|
||||
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
|
||||
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
|
||||
|
||||
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
||||
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
||||
|
||||
# Copy llama-swap binary
|
||||
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
||||
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
||||
|
||||
RUN ldconfig
|
||||
|
||||
COPY config.example.yaml /etc/llama-swap/config/config.yaml
|
||||
|
||||
# Version tracking
|
||||
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
|
||||
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
|
||||
echo "backend: ${BACKEND}" >> /versions.txt && \
|
||||
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
|
||||
|
||||
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
|
||||
WORKDIR /models
|
||||
USER ${RUN_UID}
|
||||
ENTRYPOINT ["llama-swap"]
|
||||
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
|
||||
@@ -0,0 +1,8 @@
|
||||
# Unified Docker Container
|
||||
|
||||
These scripts create a custom llama-swap container that contains:
|
||||
|
||||
- llama-server for LLMs, rerank and embedding model support
|
||||
- sd-server (stable-diffusion.cpp) for image generation
|
||||
- whisper.cpp for ASR
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for unified container with version pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build without cache
|
||||
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
|
||||
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
|
||||
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
|
||||
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
|
||||
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
|
||||
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
|
||||
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
|
||||
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
|
||||
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
|
||||
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
|
||||
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
|
||||
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
|
||||
|
||||
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
|
||||
# Requires only: git, network access to the remote.
|
||||
resolve_ref() {
|
||||
local repo_url="$1"
|
||||
local ref="$2"
|
||||
|
||||
# Full 40-char SHA — use as-is
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "${ref}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Try tag then branch (exact match)
|
||||
local hash
|
||||
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Short hash (7+ chars): scan all refs for a SHA with this prefix
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
|
||||
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
|
||||
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
|
||||
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
|
||||
else
|
||||
echo " Tried: tag, branch, git ls-remote prefix match" >&2
|
||||
fi
|
||||
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Resolve HEAD of a repo without needing to know the default branch name.
|
||||
get_latest_hash() {
|
||||
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap Unified Build (${BACKEND})"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Resolve llama.cpp ref
|
||||
if [[ -n "${LLAMA_REF:-}" ]]; then
|
||||
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
|
||||
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve whisper.cpp ref
|
||||
if [[ -n "${WHISPER_REF:-}" ]]; then
|
||||
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
|
||||
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve stable-diffusion.cpp ref
|
||||
if [[ -n "${SD_REF:-}" ]]; then
|
||||
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
|
||||
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
|
||||
else
|
||||
SD_HASH=$(get_latest_hash "${SD_REPO}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve ik_llama.cpp ref (CUDA only)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
|
||||
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
|
||||
else
|
||||
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
|
||||
if [[ -z "${IK_LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
else
|
||||
IK_LLAMA_HASH="n/a"
|
||||
echo "ik_llama.cpp: skipped (vulkan build)"
|
||||
fi
|
||||
|
||||
# Resolve llama-swap ref
|
||||
if [[ -n "${LS_VERSION:-}" ]]; then
|
||||
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
|
||||
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
|
||||
else
|
||||
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
|
||||
if [[ -z "${LS_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama-swap" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama-swap: latest HEAD: ${LS_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
|
||||
--build-arg "LS_VERSION=${LS_HASH}"
|
||||
--build-arg "RUN_UID=${RUN_UID:-0}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/Dockerfile"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
|
||||
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
|
||||
BUILD_ARGS+=(
|
||||
--cache-from "type=registry,ref=${CACHE_REF}"
|
||||
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
|
||||
)
|
||||
echo "Note: Using registry cache (${CACHE_REF})"
|
||||
fi
|
||||
|
||||
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
EXPECTED_BINARIES+=(ik-llama-server)
|
||||
fi
|
||||
|
||||
MISSING_BINARIES=()
|
||||
for binary in "${EXPECTED_BINARIES[@]}"; do
|
||||
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --${BACKEND} --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
|
||||
fi
|
||||
echo "All expected binaries verified: ${VERIFIED_LIST}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tag: ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -0,0 +1,33 @@
|
||||
# placeholder example configuration
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
|
||||
"whisper":
|
||||
checkEndpoint: /v1/audio/transcriptions/
|
||||
cmd: >
|
||||
whisper-server
|
||||
--port ${PORT}
|
||||
--m /models/whisper.bin
|
||||
--flash-attn
|
||||
--request-path /v1/audio/transcriptions --inference-path ""
|
||||
|
||||
"image":
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Install ik_llama.cpp - clone, build, and install binaries
|
||||
# Usage: ./install-ik-llama.sh <commit_hash>
|
||||
# Note: CUDA only; always built against builder-base-cuda
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-main}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
|
||||
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/ik_llama.cpp
|
||||
cd /src/ik_llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DGGML_CUDA=ON
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
|
||||
)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building ik_llama.cpp ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target llama-server
|
||||
|
||||
if [ ! -f "build/bin/llama-server" ]; then
|
||||
echo "FATAL: llama-server not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
|
||||
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
|
||||
echo "=== ik_llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
# Install llama-swap - download latest release binary from GitHub
|
||||
# Usage: ./install-llama-swap.sh [version]
|
||||
# version: release version number (e.g., "170") or "latest" (default)
|
||||
set -e
|
||||
|
||||
VERSION="${1:-latest}"
|
||||
REPO="mostlygeek/llama-swap"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# If a full commit hash is given, find the release tag that points to it
|
||||
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
|
||||
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
|
||||
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
|
||||
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
|
||||
if [ -n "${TAG}" ]; then
|
||||
echo "Resolved to tag: ${TAG}"
|
||||
VERSION="${TAG#v}"
|
||||
else
|
||||
echo "No release tag found for commit ${VERSION:0:7}, using latest"
|
||||
VERSION="latest"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Strip leading 'v' prefix so both "198" and "v198" work
|
||||
VERSION="${VERSION#v}"
|
||||
|
||||
# Resolve "latest" to actual version number
|
||||
if [ "$VERSION" = "latest" ]; then
|
||||
echo "=== Resolving latest llama-swap release ==="
|
||||
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
|
||||
if [ -z "$VERSION" ]; then
|
||||
echo "FATAL: Could not determine latest release version" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Latest version: ${VERSION}"
|
||||
fi
|
||||
|
||||
# Download and extract
|
||||
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz"
|
||||
echo "=== Downloading llama-swap v${VERSION} ==="
|
||||
echo "URL: $URL"
|
||||
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
||||
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
|
||||
rm /tmp/llama-swap.tar.gz
|
||||
|
||||
# Validate
|
||||
if [ ! -x "/install/bin/llama-swap" ]; then
|
||||
echo "FATAL: llama-swap binary not found or not executable" >&2
|
||||
ls -la /install/bin/ >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "$VERSION" > /install/llama-swap-version
|
||||
|
||||
echo "=== llama-swap v${VERSION} installed ==="
|
||||
ls -la /install/bin/llama-swap
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
# Install llama.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/llama.cpp
|
||||
cd /src/llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DLLAMA_BUILD_TESTS=OFF
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(llama-cli llama-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building llama.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
echo "=== llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
# Install stable-diffusion.cpp - clone, build, and install binaries and library
|
||||
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/stable-diffusion.cpp
|
||||
cd /src/stable-diffusion.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
git submodule update --init --recursive --depth=1
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DSD_BUILD_EXAMPLES=ON
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
-DSD_CUDA=ON
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
-DSD_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(stable-diffusion sd-cli sd-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in sd-cli sd-server; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== stable-diffusion.cpp build complete ==="
|
||||
ls -la /install/bin/ /install/lib/
|
||||
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
# Install whisper.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/whisper.cpp
|
||||
cd /src/whisper.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/whisper.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(whisper-cli whisper-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building whisper.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== whisper.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 351 KiB |
|
After Width: | Height: | Size: 198 KiB |
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
@@ -114,6 +117,24 @@ healthCheckTimeout: 500
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
@@ -126,6 +147,30 @@ metricsMaxInMemory: 1000
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -274,6 +319,33 @@ models:
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models on slower hardware that need longer timeouts
|
||||
# - increase responseHeader to avoid "timeout awaiting response headers" errors
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
# connect: TCP connection timeout in seconds
|
||||
# - default: 30
|
||||
connect: 30
|
||||
|
||||
# responseHeader: time to wait for response headers in seconds
|
||||
# - default: 60
|
||||
# - for slow image generation or large models, consider increasing to 300+ seconds
|
||||
responseHeader: 60
|
||||
|
||||
# tlsHandshake: TLS handshake timeout in seconds
|
||||
# - default: 10
|
||||
tlsHandshake: 10
|
||||
|
||||
# idleConn: idle connection timeout in seconds
|
||||
# - default: 90
|
||||
idleConn: 90
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
@@ -383,4 +455,47 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
apiKey: sk-your-openrouter-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
```
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
## Container Security
|
||||
|
||||
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||
|
||||
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||
|
||||
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||
|
||||
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
@@ -37,9 +37,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
@@ -81,6 +87,7 @@ type GroupConfig struct {
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
@@ -114,7 +121,10 @@ type Config struct {
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -136,6 +146,12 @@ type Config struct {
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -170,22 +186,31 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// default configuration values
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
@@ -193,6 +218,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
@@ -204,55 +239,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -262,23 +297,40 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
@@ -287,29 +339,25 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
@@ -319,13 +367,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
@@ -333,35 +383,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
continue // replaced at runtime
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
// if sendLoadingState is nil, set it to the global config value
|
||||
// see #366
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState // copy it
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
@@ -369,18 +439,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
@@ -388,7 +457,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
@@ -400,10 +469,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -534,20 +649,26 @@ func validateMacro(name string, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -555,7 +676,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -614,3 +735,67 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
@@ -186,6 +187,13 @@ groups:
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -194,6 +202,13 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -202,6 +217,13 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -210,10 +232,18 @@ groups:
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||
@@ -761,3 +762,785 @@ models:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
content: `apiKeys: [""]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, tt.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret-key-123")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY}"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY_1", "key-one")
|
||||
t.Setenv("TEST_API_KEY_2", "key-two")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("missing env var in apiKeys", func(t *testing.T) {
|
||||
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
|
||||
})
|
||||
|
||||
t.Run("env substitution results in empty key", func(t *testing.T) {
|
||||
t.Setenv("TEST_EMPTY_KEY", "")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "empty api key found in apiKeys", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_GlobalTTL(t *testing.T) {
|
||||
t.Run("globalTTL sets default for models", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 300, config.GlobalTTL)
|
||||
assert.Equal(t, 300, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 0
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 600
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 600, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("globalTTL defaults to 0", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.GlobalTTL)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("negative globalTTL rejected", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: -1
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "globalTTL must be >= 0")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_EnvMacros(t *testing.T) {
|
||||
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.TEST_MODEL_PATH}/llama-server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env substitution in multiple fields", func(t *testing.T) {
|
||||
t.Setenv("TEST_HOST", "myserver")
|
||||
t.Setenv("TEST_PORT", "9999")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --host ${env.TEST_HOST}"
|
||||
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
|
||||
checkEndpoint: "http://${env.TEST_HOST}/health"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
|
||||
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
|
||||
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
|
||||
})
|
||||
|
||||
t.Run("env in global macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_BASE_PATH", "/usr/local")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
|
||||
models:
|
||||
test:
|
||||
cmd: "${SERVER_PATH} --port 8080"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in model-level macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_DIR", "/models/llama")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
|
||||
cmd: "server --model ${MODEL_FILE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in metadata", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
api_key: "${env.TEST_API_KEY}"
|
||||
nested:
|
||||
key: "${env.TEST_API_KEY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
|
||||
nested := config.Models["test"].Metadata["nested"].(map[string]any)
|
||||
assert.Equal(t, "secret123", nested["key"])
|
||||
})
|
||||
|
||||
t.Run("env in filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("env in cmdStop", func(t *testing.T) {
|
||||
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --port ${PORT}"
|
||||
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
|
||||
proxy: "http://localhost:${PORT}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
|
||||
})
|
||||
|
||||
t.Run("missing env var returns error", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.UNDEFINED_VAR_12345}/server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in global macro", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in model macro", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in metadata", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
key: "${env.UNDEFINED_META_VAR}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env combined with regular macros", func(t *testing.T) {
|
||||
t.Setenv("TEST_ROOT", "/data")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
MODEL_BASE: "${env.TEST_ROOT}/models"
|
||||
models:
|
||||
test:
|
||||
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("multiple env vars in same string", func(t *testing.T) {
|
||||
t.Setenv("TEST_USER", "admin")
|
||||
t.Setenv("TEST_PASS", "secret")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env value with newline is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_MULTILINE", "line1\nline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_MULTILINE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_MULTILINE")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with carriage return is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_CR", "line1\rline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_CR}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_CR")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_QUOTED", `value with "quotes"`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --arg \"${env.TEST_QUOTED}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Quotes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with quotes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
|
||||
})
|
||||
|
||||
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_BACKSLASH", `path\to\file`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --path \"${env.TEST_BACKSLASH}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Backslashes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with backslashes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in peer apiKey", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.TEST_PEER_API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("missing env var in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.NONEXISTENT_PEER_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
|
||||
})
|
||||
|
||||
t.Run("static apiKey unchanged", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-static-key
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_KEY_1", "key-one")
|
||||
t.Setenv("TEST_PEER_KEY_2", "key-two")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: https://peer1.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_1}"
|
||||
models:
|
||||
- model-a
|
||||
peer2:
|
||||
proxy: https://peer2.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_2}"
|
||||
models:
|
||||
- model-b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
|
||||
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
API_KEY: sk-from-global-macro
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
STRIP_LIST: "temperature, top_p"
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${STRIP_LIST}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
MAX_TOKENS: 4096
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
max_tokens: "${MAX_TOKENS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_RETENTION_POLICY", "deny")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
data_collection: "${env.TEST_RETENTION_POLICY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${UNDEFINED_MACRO}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
value: "${UNDEFINED_MACRO}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("env macros in comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
# apiKeys:
|
||||
# - "${env.COMMENTED_OUT_KEY_1}"
|
||||
# - "${env.COMMENTED_OUT_KEY_2}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
// These env vars are NOT set, but should not cause an error
|
||||
// because they only appear in comment lines
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in comments ignored while active ones resolve", func(t *testing.T) {
|
||||
t.Setenv("TEST_ACTIVE_KEY", "active-key-value")
|
||||
|
||||
content := `
|
||||
# apiKeys: ["${env.COMMENTED_OUT_KEY}"]
|
||||
apiKeys: ["${env.TEST_ACTIVE_KEY}"]
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"active-key-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("env macros in indented comments are ignored", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: |
|
||||
server
|
||||
--port 8080
|
||||
proxy: "http://localhost:8080"
|
||||
# metadata:
|
||||
# api_key: "${env.SOME_UNSET_KEY}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("env macros in inline comments are ignored", func(t *testing.T) {
|
||||
t.Setenv("TEST_INLINE_KEY", "real-value")
|
||||
|
||||
content := `
|
||||
apiKeys: ["${env.TEST_INLINE_KEY}"] # TODO: add ${env.FUTURE_KEY} later
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"real-value"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsParsing(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
timeouts:
|
||||
connect: 45
|
||||
responseHeader: 120
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
assert.Equal(t, 45, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 120, modelConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsDefaults(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
// Default values should be set during unmarshaling
|
||||
assert.Equal(t, 30, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 60, modelConfig.Timeouts.ResponseHeader)
|
||||
assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake)
|
||||
assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, modelConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
func TestConfig_TimeoutsZeroAllowed(t *testing.T) {
|
||||
configYaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: test-server --port ${PORT}
|
||||
timeouts:
|
||||
connect: 0
|
||||
responseHeader: 0
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfig, found := config.Models["model1"]
|
||||
require.True(t, found, "model1 should exist in config")
|
||||
|
||||
// Explicit 0 should be preserved (disables timeout)
|
||||
assert.Equal(t, 0, modelConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_PeerTimeoutsParsing(t *testing.T) {
|
||||
configYaml := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: http://example.com
|
||||
models: [model1]
|
||||
timeouts:
|
||||
connect: 45
|
||||
responseHeader: 120
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerConfig, found := config.Peers["peer1"]
|
||||
require.True(t, found, "peer1 should exist in config")
|
||||
|
||||
assert.Equal(t, 45, peerConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 120, peerConfig.Timeouts.ResponseHeader)
|
||||
}
|
||||
|
||||
func TestConfig_PeerTimeoutsDefaults(t *testing.T) {
|
||||
configYaml := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: http://example.com
|
||||
models: [model1]
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerConfig, found := config.Peers["peer1"]
|
||||
require.True(t, found, "peer1 should exist in config")
|
||||
|
||||
// Default values should be set during unmarshaling
|
||||
assert.Equal(t, 30, peerConfig.Timeouts.Connect)
|
||||
assert.Equal(t, 60, peerConfig.Timeouts.ResponseHeader)
|
||||
assert.Equal(t, 10, peerConfig.Timeouts.TLSHandshake)
|
||||
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||
}
|
||||
|
||||
@@ -158,6 +158,7 @@ groups:
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
@@ -172,6 +173,13 @@ groups:
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
@@ -181,6 +189,13 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -190,6 +205,13 @@ groups:
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
@@ -199,10 +221,18 @@ groups:
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
|
||||
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||
// Useful with aliases: a single loaded model can behave differently depending on
|
||||
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||
// Protected params (like "model") cannot be set.
|
||||
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||
// with protected params removed and keys sorted for consistent iteration order.
|
||||
// Returns nil if the ID has no entry or all its params are protected.
|
||||
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||
if len(f.SetParamsByID) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
params, found := f.SetParamsByID[requestedModelID]
|
||||
if !found || len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]any, len(params))
|
||||
keys := make([]string, 0, len(params))
|
||||
for key, value := range params {
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, keys
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParamsByID map[string]map[string]any
|
||||
requestedModelID string
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty SetParamsByID returns nil",
|
||||
setParamsByID: nil,
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map returns nil",
|
||||
setParamsByID: map[string]map[string]any{},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "non-matching model ID returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model2": {"temperature": 0.9},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "matching model ID returns correct params",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||
"model2": {"temperature": 0.5},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected param model is filtered out",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "keys are sorted",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||
},
|
||||
{
|
||||
name: "alias style key lookup",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1:high": {"reasoning_effort": "high"},
|
||||
"model1:low": {"reasoning_effort": "low"},
|
||||
},
|
||||
requestedModelID: "model1:high",
|
||||
wantParams: map[string]any{
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
wantKeys: []string{"reasoning_effort"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams)
|
||||
assert.Nil(t, gotKeys)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||
assert.Equal(t, tt.wantParams, gotParams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -104,6 +104,62 @@ models:
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test macro substitution in name and description fields
|
||||
func TestConfig_MacroInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"VARIANT": "Q4_K_M"
|
||||
"FAMILY": "llama"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "${FAMILY} ${VARIANT}"
|
||||
description: "A ${FAMILY} model in ${VARIANT} format"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
|
||||
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
|
||||
}
|
||||
|
||||
// Test MODEL_ID macro in name and description fields
|
||||
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
llama-3b:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model: ${MODEL_ID}"
|
||||
description: "Running ${MODEL_ID}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
|
||||
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
|
||||
}
|
||||
|
||||
// Test unknown macro in name or description returns an error
|
||||
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model ${UNDEFINED}"
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
|
||||
@@ -3,10 +3,21 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
type TimeoutsConfig struct {
|
||||
Connect int `yaml:"connect"` // seconds, 0 = no timeout (not recommended)
|
||||
ResponseHeader int `yaml:"responseHeader"` // seconds, 0 = no timeout (not recommended)
|
||||
TLSHandshake int `yaml:"tlsHandshake"` // seconds, 0 = no timeout (not recommended)
|
||||
ExpectContinue int `yaml:"expectContinue"` // seconds, 0 = no timeout (not recommended)
|
||||
IdleConn int `yaml:"idleConn"` // seconds, 0 = no timeout (not recommended)
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
@@ -38,6 +49,9 @@ type ModelConfig struct {
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
@@ -49,12 +63,19 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
@@ -74,16 +95,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
@@ -104,25 +124,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
|
||||
@@ -72,3 +72,101 @@ models:
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Keys (other than the model's own ID) should be registered as aliases
|
||||
realName, found := cfg.RealModelName("model1:high")
|
||||
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
realName, found = cfg.RealModelName("model1:low")
|
||||
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
// Auto-aliases should also appear in modelConfig.Aliases
|
||||
aliases := cfg.Models["model1"].Aliases
|
||||
assert.Contains(t, aliases, "model1:high")
|
||||
assert.Contains(t, aliases, "model1:low")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
model2:
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: low
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "duplicate alias")
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
|
||||
@@ -71,11 +71,15 @@ func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, simpleResponderPath, port, expectedMessage, port)
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,6 +11,85 @@ import (
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
|
||||
func newCircularBuffer(capacity int) *circularBuffer {
|
||||
return &circularBuffer{
|
||||
data: make([]byte, capacity),
|
||||
head: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||
// Data is copied into the internal buffer (not stored by reference).
|
||||
func (cb *circularBuffer) Write(p []byte) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cap := len(cb.data)
|
||||
|
||||
// If input is larger than capacity, only keep the last cap bytes
|
||||
if len(p) >= cap {
|
||||
copy(cb.data, p[len(p)-cap:])
|
||||
cb.head = 0
|
||||
cb.size = cap
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate how much space is available from head to end of buffer
|
||||
firstPart := cap - cb.head
|
||||
if firstPart >= len(p) {
|
||||
// All data fits without wrapping
|
||||
copy(cb.data[cb.head:], p)
|
||||
cb.head = (cb.head + len(p)) % cap
|
||||
} else {
|
||||
// Data wraps around
|
||||
copy(cb.data[cb.head:], p[:firstPart])
|
||||
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||
cb.head = len(p) - firstPart
|
||||
}
|
||||
|
||||
// Update size
|
||||
cb.size += len(p)
|
||||
if cb.size > cap {
|
||||
cb.size = cap
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||
// Returns a new slice (copy), not a view into internal buffer.
|
||||
func (cb *circularBuffer) GetHistory() []byte {
|
||||
if cb.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, cb.size)
|
||||
cap := len(cb.data)
|
||||
|
||||
// Calculate start position (oldest data)
|
||||
start := (cb.head - cb.size + cap) % cap
|
||||
|
||||
if start+cb.size <= cap {
|
||||
// Data is contiguous, single copy
|
||||
copy(result, cb.data[start:start+cb.size])
|
||||
} else {
|
||||
// Data wraps around, two copies
|
||||
firstPart := cap - start
|
||||
copy(result[:firstPart], cb.data[start:])
|
||||
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
@@ -19,12 +97,14 @@ const (
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
|
||||
LogBufferSize = 100 * 1024
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
buffer *circularBuffer
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
@@ -45,7 +125,7 @@ func NewLogMonitor() *LogMonitor {
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
buffer: nil, // lazy initialized on first Write
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
@@ -64,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
if w.buffer == nil {
|
||||
w.buffer = newCircularBuffer(LogBufferSize)
|
||||
}
|
||||
w.buffer.Write(p)
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
// Make a copy for broadcast to preserve immutability
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
@@ -77,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return w.buffer.GetHistory()
|
||||
}
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
})
|
||||
return history
|
||||
// Clear releases the buffer memory, making it eligible for GC.
|
||||
// The buffer will be lazily re-allocated on the next Write.
|
||||
func (w *LogMonitor) Clear() {
|
||||
w.bufferMu.Lock()
|
||||
w.buffer = nil
|
||||
w.bufferMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
|
||||
@@ -113,3 +113,204 @@ func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||
// Create a small buffer to test wrap-around
|
||||
cb := newCircularBuffer(10)
|
||||
|
||||
// Write "hello" (5 bytes)
|
||||
cb.Write([]byte("hello"))
|
||||
if got := string(cb.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Write "world" (5 bytes) - buffer now full
|
||||
cb.Write([]byte("world"))
|
||||
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||
t.Errorf("Expected 'helloworld', got %q", got)
|
||||
}
|
||||
|
||||
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||
cb.Write([]byte("12345"))
|
||||
if got := string(cb.GetHistory()); got != "world12345" {
|
||||
t.Errorf("Expected 'world12345', got %q", got)
|
||||
}
|
||||
|
||||
// Write data larger than buffer capacity
|
||||
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||
// Test empty buffer
|
||||
cb := newCircularBuffer(10)
|
||||
if got := cb.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
|
||||
// Test exact capacity
|
||||
cb.Write([]byte("1234567890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
|
||||
// Test write exactly at capacity boundary
|
||||
cb = newCircularBuffer(10)
|
||||
cb.Write([]byte("12345"))
|
||||
cb.Write([]byte("67890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Buffer should be nil before any writes
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil before first write")
|
||||
}
|
||||
|
||||
// GetHistory should return nil when buffer is nil
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history before first write, got %q", got)
|
||||
}
|
||||
|
||||
// Write should lazily initialize the buffer
|
||||
lm.Write([]byte("test"))
|
||||
|
||||
if lm.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized after write")
|
||||
}
|
||||
|
||||
if got := string(lm.GetHistory()); got != "test" {
|
||||
t.Errorf("Expected 'test', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_Clear(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write some data
|
||||
lm.Write([]byte("hello"))
|
||||
if got := string(lm.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Clear should release the buffer
|
||||
lm.Clear()
|
||||
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil after Clear")
|
||||
}
|
||||
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write, clear, then write again
|
||||
lm.Write([]byte("first"))
|
||||
lm.Clear()
|
||||
lm.Write([]byte("second"))
|
||||
|
||||
if got := string(lm.GetHistory()); got != "second" {
|
||||
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||
// Test data of varying sizes
|
||||
smallMsg := []byte("small message\n")
|
||||
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||
|
||||
b.Run("SmallWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(smallMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MediumWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LargeWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(largeMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithSubscribers", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Add some subscribers
|
||||
for i := 0; i < 5; i++ {
|
||||
lm.OnLogData(func(data []byte) {})
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetHistory", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Pre-populate with data
|
||||
for i := 0; i < 1000; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.GetHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
Benchmark Results - MBP M1 Pro
|
||||
|
||||
Before (ring.Ring):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||
|
||||
After (circularBuffer 10KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||
|
||||
After (circularBuffer 100KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|-----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||
|
||||
Summary:
|
||||
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||
- Allocations: reduced from 2 to 1 across all operations
|
||||
- Small/medium writes: ~1.1-1.6x faster
|
||||
*/
|
||||
|
||||
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -26,6 +28,28 @@ type TokenMetrics struct {
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// Size returns the approximate memory usage of this capture in bytes
|
||||
func (c *ReqRespCapture) Size() int {
|
||||
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
||||
for k, v := range c.ReqHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
for k, v := range c.RespHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
@@ -44,19 +68,32 @@ type metricsMonitor struct {
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total size in bytes
|
||||
maxCaptureSize int // max bytes for captures
|
||||
}
|
||||
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||
mp := &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int]ReqRespCapture),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
@@ -67,6 +104,49 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
}
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
captureSize := capture.Size()
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
mp.captureSize -= evicted.Size()
|
||||
delete(mp.captures, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = capture
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
}
|
||||
|
||||
// getCaptureByID returns a capture by its ID, or nil if not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
if capture, exists := mp.captures[id]; exists {
|
||||
return &capture
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
@@ -95,7 +175,35 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,30 +216,94 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
// Only set HasCapture if the capture will actually be stored (not too large)
|
||||
if capture.Size() <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,19 +346,27 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
return parseMetrics(modelID, start, gjson.ParseBytes(data))
|
||||
parsed := gjson.ParseBytes(data)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// v1/responses format nests usage under response.usage
|
||||
if !usage.Exists() {
|
||||
usage = parsed.Get("response.usage")
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
return parseMetrics(modelID, start, usage, timings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
|
||||
usage := jsonData.Get("usage")
|
||||
timings := jsonData.Get("timings")
|
||||
if !usage.Exists() && !timings.Exists() {
|
||||
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
|
||||
}
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
// default values
|
||||
cachedTokens := -1 // unknown or missing data
|
||||
outputTokens := 0
|
||||
@@ -195,22 +375,41 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
||||
// timings data
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := int(time.Since(start).Milliseconds())
|
||||
durationMs := wallDurationMs
|
||||
|
||||
if usage.Exists() {
|
||||
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||
// v1/chat/completions
|
||||
inputTokens = int(pt.Int())
|
||||
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||
// v1/messages
|
||||
inputTokens = int(it.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||
// v1/chat/completions
|
||||
outputTokens = int(ct.Int())
|
||||
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||
outputTokens = int(ot.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||
cachedTokens = int(ct.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings.Exists() {
|
||||
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||
inputTokens = int(timings.Get("prompt_n").Int())
|
||||
outputTokens = int(timings.Get("predicted_n").Int())
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
|
||||
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = int(cachedValue.Int())
|
||||
}
|
||||
}
|
||||
@@ -227,6 +426,25 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
@@ -265,3 +483,43 @@ func (w *responseBodyCopier) Header() http.Header {
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,11 +14,12 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -34,7 +38,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||
@@ -48,7 +52,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 3)
|
||||
mm := newMetricsMonitor(testLogger, 3, 0)
|
||||
|
||||
// Add 5 metrics
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -68,7 +72,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||
cancel := event.On(func(e TokenMetricsEvent) {
|
||||
@@ -98,14 +102,14 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
metrics := mm.getMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||
|
||||
@@ -125,7 +129,7 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
jsonData, err := mm.getMetricsJSON()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
@@ -137,7 +141,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
mm.addMetrics(TokenMetrics{
|
||||
Model: "model1",
|
||||
InputTokens: 100,
|
||||
@@ -165,7 +169,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"usage": {
|
||||
@@ -196,7 +200,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("successful request with timings data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -236,7 +240,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||
@@ -272,7 +276,7 @@ data: [DONE]
|
||||
})
|
||||
|
||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -291,8 +295,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -307,11 +311,14 @@ data: [DONE]
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -328,11 +335,14 @@ data: [DONE]
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
expectedErr := assert.AnError
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -350,8 +360,8 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"result": "ok"}`
|
||||
|
||||
@@ -367,10 +377,82 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Infill response is an array with timings in the last element
|
||||
responseBody := `[
|
||||
{"content": "first chunk"},
|
||||
{"content": "second chunk"},
|
||||
{"content": "final", "timings": {
|
||||
"prompt_n": 150,
|
||||
"predicted_n": 75,
|
||||
"prompt_per_second": 200.5,
|
||||
"predicted_per_second": 35.5,
|
||||
"prompt_ms": 600.0,
|
||||
"predicted_ms": 1800.0,
|
||||
"cache_n": 30
|
||||
}}
|
||||
]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 150, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
assert.Equal(t, 30, metrics[0].CachedTokens)
|
||||
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
||||
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
||||
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||
})
|
||||
|
||||
t.Run("infill request with empty array records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `[]`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/infill", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -425,7 +507,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
@@ -452,7 +534,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
@@ -489,8 +571,29 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
t.Run("keeps wall clock duration when timings underreport request time", func(t *testing.T) {
|
||||
start := time.Now().Add(-5 * time.Second)
|
||||
usage := gjson.Parse(`{"prompt_tokens": 5, "completion_tokens": 1}`)
|
||||
timings := gjson.Parse(`{
|
||||
"prompt_n": 5,
|
||||
"predicted_n": 1,
|
||||
"prompt_per_second": 10.0,
|
||||
"predicted_per_second": 2.0,
|
||||
"prompt_ms": 5.0,
|
||||
"predicted_ms": 15.0
|
||||
}`)
|
||||
|
||||
metrics, err := parseMetrics("test-model", start, usage, timings)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, metrics.InputTokens)
|
||||
assert.Equal(t, 1, metrics.OutputTokens)
|
||||
assert.Equal(t, 10.0, metrics.PromptPerSecond)
|
||||
assert.Equal(t, 2.0, metrics.TokensPerSecond)
|
||||
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
||||
})
|
||||
|
||||
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Timings should take precedence over usage
|
||||
responseBody := `{
|
||||
@@ -530,7 +633,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{
|
||||
"timings": {
|
||||
@@ -565,7 +668,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||
|
||||
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Metrics should be found in the last data line before [DONE]
|
||||
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||
@@ -598,8 +701,8 @@ data: [DONE]
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `data: not json
|
||||
|
||||
@@ -619,14 +722,46 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// v1/responses SSE format: usage is nested under response.usage
|
||||
responseBody := "event: response.completed\n" +
|
||||
`data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","created_at":1773416985,"status":"completed","model":"test-model","output":[],"usage":{"input_tokens":17,"output_tokens":23,"total_tokens":40}}}` +
|
||||
"\n\n"
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 17, metrics[0].InputTokens)
|
||||
assert.Equal(t, 23, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := ``
|
||||
|
||||
@@ -642,17 +777,19 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
// Empty body should not trigger WrapHandler processing
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
mm := newMetricsMonitor(testLogger, 1000)
|
||||
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -673,7 +810,7 @@ func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||
|
||||
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||
mm := newMetricsMonitor(testLogger, 100)
|
||||
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||
|
||||
metric := TokenMetrics{
|
||||
Model: "test-model",
|
||||
@@ -691,3 +828,352 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
t.Run("gzip encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
// Compress with gzip
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
gzWriter.Write([]byte(responseBody))
|
||||
gzWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("deflate encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
|
||||
|
||||
// Compress with deflate
|
||||
var buf bytes.Buffer
|
||||
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
|
||||
flateWriter.Write([]byte(responseBody))
|
||||
flateWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "deflate")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
// Invalid compressed data
|
||||
invalidData := []byte("this is not gzip data")
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(invalidData)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Should not return error, just log warning
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "unknown-encoding")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReqRespCapture_Size(t *testing.T) {
|
||||
t.Run("calculates size correctly", func(t *testing.T) {
|
||||
capture := ReqRespCapture{
|
||||
ID: 1,
|
||||
ReqPath: "/v1/chat/completions", // 20 bytes
|
||||
ReqHeaders: map[string]string{
|
||||
"Content-Type": "application/json", // 12 + 16 = 28
|
||||
},
|
||||
ReqBody: []byte("request body"), // 12 bytes
|
||||
RespHeaders: map[string]string{
|
||||
"X-Test": "value", // 6 + 5 = 11
|
||||
},
|
||||
RespBody: []byte("response body"), // 13 bytes
|
||||
}
|
||||
|
||||
// Expected: 20 + 12 + 13 + 28 + 11 = 84
|
||||
assert.Equal(t, 84, capture.Size())
|
||||
})
|
||||
|
||||
t.Run("handles empty capture", func(t *testing.T) {
|
||||
capture := ReqRespCapture{}
|
||||
assert.Equal(t, 0, capture.Size())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||
t.Run("does nothing when captures disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
// Should not store capture
|
||||
assert.Nil(t, mm.getCaptureByID(0))
|
||||
})
|
||||
|
||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 0,
|
||||
ReqBody: []byte("test request"),
|
||||
RespBody: []byte("test response"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(0)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 0, retrieved.ID)
|
||||
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
|
||||
assert.Equal(t, []byte("test response"), retrieved.RespBody)
|
||||
})
|
||||
|
||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100 // Set small limit for test
|
||||
|
||||
// Add captures that will exceed the limit
|
||||
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
|
||||
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
|
||||
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
|
||||
|
||||
mm.addCapture(capture1)
|
||||
mm.addCapture(capture2)
|
||||
// Adding capture3 should evict capture1
|
||||
mm.addCapture(capture3)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
||||
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
||||
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
||||
})
|
||||
|
||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
mm.maxCaptureSize = 100
|
||||
|
||||
// Add a capture larger than max
|
||||
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
|
||||
mm.addCapture(largeCapture)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
assert.Nil(t, mm.getCaptureByID(999))
|
||||
})
|
||||
|
||||
t.Run("returns capture by ID", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
capture := ReqRespCapture{
|
||||
ID: 42,
|
||||
ReqBody: []byte("test"),
|
||||
}
|
||||
mm.addCapture(capture)
|
||||
|
||||
retrieved := mm.getCaptureByID(42)
|
||||
assert.NotNil(t, retrieved)
|
||||
assert.Equal(t, 42, retrieved.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Run("redacts sensitive headers", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer secret-token",
|
||||
"Proxy-Authorization": "Basic creds",
|
||||
"Cookie": "session=abc123",
|
||||
"Set-Cookie": "session=xyz789",
|
||||
"X-Api-Key": "sk-12345",
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom": "safe-value",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["Set-Cookie"])
|
||||
assert.Equal(t, "[REDACTED]", headers["X-Api-Key"])
|
||||
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||
assert.Equal(t, "safe-value", headers["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("handles mixed case header names", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"authorization": "Bearer token",
|
||||
"COOKIE": "session=abc",
|
||||
"x-api-key": "key123",
|
||||
}
|
||||
|
||||
redactHeaders(headers)
|
||||
|
||||
assert.Equal(t, "[REDACTED]", headers["authorization"])
|
||||
assert.Equal(t, "[REDACTED]", headers["COOKIE"])
|
||||
assert.Equal(t, "[REDACTED]", headers["x-api-key"])
|
||||
})
|
||||
|
||||
t.Run("handles empty headers", func(t *testing.T) {
|
||||
headers := map[string]string{}
|
||||
redactHeaders(headers)
|
||||
assert.Empty(t, headers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||
t.Run("captures request and response when enabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||
|
||||
requestBody := `{"model": "test", "prompt": "hello"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Custom", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check metric was recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
metricID := metrics[0].ID
|
||||
|
||||
// Check capture was stored with same ID
|
||||
capture := mm.getCaptureByID(metricID)
|
||||
assert.NotNil(t, capture)
|
||||
assert.Equal(t, metricID, capture.ID)
|
||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||
assert.Equal(t, "/test", capture.ReqPath)
|
||||
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||
})
|
||||
|
||||
t.Run("does not capture when disabled", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||
|
||||
requestBody := `{"model": "test"}`
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Metrics should still be recorded
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
|
||||
// But no capture
|
||||
capture := mm.getCaptureByID(metrics[0].ID)
|
||||
assert.Nil(t, capture)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
@@ -96,6 +96,24 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
|
||||
var reverseProxy *httputil.ReverseProxy
|
||||
if proxyURL != nil {
|
||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||
|
||||
// Create custom transport with configured timeouts
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
reverseProxy.Transport = transport
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
@@ -256,6 +274,7 @@ func (p *Process) start() error {
|
||||
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||
setProcAttributes(p.cmd)
|
||||
|
||||
p.cmdMutex.Lock()
|
||||
p.cancelUpstream = ctxCancelUpstream
|
||||
@@ -413,6 +432,9 @@ func (p *Process) stopCommand() {
|
||||
stopStartTime := time.Now()
|
||||
defer func() {
|
||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||
|
||||
// free the buffer in processLogger so the memory can be recovered
|
||||
p.processLogger.Clear()
|
||||
}()
|
||||
|
||||
p.cmdMutex.RLock()
|
||||
@@ -506,7 +528,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
|
||||
|
||||
// PR #417 (no support for anthropic v1/messages yet)
|
||||
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
@@ -625,6 +650,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
setProcAttributes(stopCmd)
|
||||
stopCmd.Env = p.cmd.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
@@ -641,6 +667,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger returns the logger for this process.
|
||||
func (p *Process) Logger() *LogMonitor {
|
||||
return p.processLogger
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
@@ -859,7 +890,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
// Add Flush method
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -117,12 +118,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -159,12 +160,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
@@ -395,6 +396,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
@@ -565,3 +570,39 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||
modelConfig := config.ModelConfig{
|
||||
Cmd: "echo test",
|
||||
Proxy: "http://localhost:8080",
|
||||
CheckEndpoint: "/health",
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 120,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
}
|
||||
|
||||
debugLogger := NewLogMonitorWriter(io.Discard)
|
||||
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||
|
||||
// Verify the process was created successfully
|
||||
assert.NotNil(t, process)
|
||||
assert.Equal(t, "test-model", process.ID)
|
||||
assert.NotNil(t, process.reverseProxy)
|
||||
assert.NotNil(t, process.reverseProxy.Transport)
|
||||
|
||||
// Verify it's using http.Transport (not some other type)
|
||||
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||
assert.True(t, ok, "Transport should be *http.Transport")
|
||||
assert.NotNil(t, transport)
|
||||
|
||||
// Verify the timeouts are correctly applied
|
||||
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
@@ -88,6 +89,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
@@ -27,6 +28,40 @@ const (
|
||||
|
||||
type proxyCtxKey string
|
||||
|
||||
type InflightCounter struct {
|
||||
mu sync.Mutex
|
||||
total int
|
||||
}
|
||||
|
||||
func newInflightCounter() *InflightCounter {
|
||||
return &InflightCounter{}
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Current() int {
|
||||
ic.mu.Lock()
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Increment() int {
|
||||
ic.mu.Lock()
|
||||
ic.total++
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
func (ic *InflightCounter) Decrement() int {
|
||||
ic.mu.Lock()
|
||||
if ic.total > 0 {
|
||||
ic.total--
|
||||
}
|
||||
total := ic.total
|
||||
ic.mu.Unlock()
|
||||
return total
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
@@ -42,6 +77,8 @@ type ProxyManager struct {
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
inFlightCounter *InflightCounter
|
||||
|
||||
// shutdown signaling
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
@@ -50,19 +87,42 @@ type ProxyManager struct {
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
|
||||
// peer proxy see: #296, #433
|
||||
peerProxy *PeerProxy
|
||||
}
|
||||
|
||||
func New(config config.Config) *ProxyManager {
|
||||
func New(proxyConfig config.Config) *ProxyManager {
|
||||
// set up loggers
|
||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||
|
||||
if config.LogRequests {
|
||||
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
|
||||
switch proxyConfig.LogToStdout {
|
||||
case config.LogToStdoutNone:
|
||||
muxLogger = NewLogMonitorWriter(io.Discard)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
case config.LogToStdoutBoth:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
case config.LogToStdoutUpstream:
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||
default:
|
||||
// same as config.LogToStdoutProxy
|
||||
// helpful because some old tests create a config.Config directly and it
|
||||
// may not have LogToStdout set explicitly
|
||||
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||
}
|
||||
|
||||
if proxyConfig.LogRequests {
|
||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
||||
case "debug":
|
||||
proxyLogger.SetLogLevel(LevelDebug)
|
||||
upstreamLogger.SetLogLevel(LevelDebug)
|
||||
@@ -99,7 +159,7 @@ func New(config config.Config) *ProxyManager {
|
||||
"stampnano": time.StampNano,
|
||||
}
|
||||
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
|
||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
||||
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||
}
|
||||
@@ -107,61 +167,78 @@ func New(config config.Config) *ProxyManager {
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
|
||||
var maxMetrics int
|
||||
if config.MetricsMaxInMemory <= 0 {
|
||||
if proxyConfig.MetricsMaxInMemory <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
} else {
|
||||
maxMetrics = config.MetricsMaxInMemory
|
||||
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||
peerProxy = nil
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: config,
|
||||
config: proxyConfig,
|
||||
ginEngine: gin.New(),
|
||||
|
||||
proxyLogger: proxyLogger,
|
||||
muxLogger: stdoutLogger,
|
||||
muxLogger: muxLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
inFlightCounter: newInflightCounter(),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
for groupID := range config.Groups {
|
||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||
for groupID := range proxyConfig.Groups {
|
||||
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||
pm.processGroups[groupID] = processGroup
|
||||
}
|
||||
|
||||
pm.setupGinEngine()
|
||||
|
||||
// run any startup hooks
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||
|
||||
if !ok {
|
||||
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
@@ -236,35 +313,50 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
// Protected routes use pm.apiKeyAuth() middleware
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
// Support anthropic count_tokens API (Also added in the above PR)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
// sd.cpp /sdapi/v1 endpoints
|
||||
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -276,9 +368,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
@@ -300,25 +392,35 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
if err != nil {
|
||||
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||
} else {
|
||||
// Serve files with compression support under /ui/*
|
||||
// This handler checks for pre-compressed .br and .gz files
|
||||
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
||||
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
||||
// Default to index.html for directory-like paths
|
||||
if filepath == "" {
|
||||
filepath = "index.html"
|
||||
}
|
||||
|
||||
// serve files that exist under /ui/*
|
||||
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
||||
})
|
||||
|
||||
// server SPA for UI under /ui/*
|
||||
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
||||
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
file, err := reactFS.Open("index.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
// Check if this looks like a file request (has extension)
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
||||
// This was likely a file request that wasn't found
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||
|
||||
// Serve index.html for SPA routing
|
||||
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -330,6 +432,14 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
gin.DisableConsoleColor()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
||||
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pm.ginEngine.ServeHTTP(w, r)
|
||||
@@ -376,16 +486,10 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
@@ -397,54 +501,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
return processGroup, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := func(modelId string) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
data = append(data, newRecord(id))
|
||||
data = append(data, newRecord(id, modelConfig))
|
||||
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias))
|
||||
data = append(data, newRecord(alias, modelConfig))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
// add peer models
|
||||
for _, modelID := range peer.Models {
|
||||
// Skip unlisted models if not showing them
|
||||
record := newRecord(modelID, config.ModelConfig{
|
||||
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||
Metadata: map[string]any{
|
||||
"peerID": peerID,
|
||||
},
|
||||
})
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
@@ -464,62 +585,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
// split the upstream path by / and search for the model name
|
||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||
if len(parts) == 0 {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
modelFound := false
|
||||
// findModelInPath searches for a valid model name in a path with slashes.
|
||||
// It iteratively builds up path segments until it finds a matching model.
|
||||
// Returns: (searchModelName, realModelName, remainingPath, found)
|
||||
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
||||
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
||||
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||
searchModelName := ""
|
||||
var modelName, remainingPath string
|
||||
|
||||
for i, part := range parts {
|
||||
if parts[i] == "" {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if searchModelName == "" {
|
||||
searchModelName = part
|
||||
} else {
|
||||
searchModelName = searchModelName + "/" + parts[i]
|
||||
searchModelName = searchModelName + "/" + part
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
break
|
||||
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||
// This ensures relative URLs in upstream responses resolve correctly and
|
||||
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||
// HTTP method (301 would downgrade to GET).
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -531,21 +651,21 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
@@ -558,41 +678,101 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||
for _, key := range setParamsByIDKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -605,19 +785,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -637,9 +817,29 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
// Look for a matching local model first, then check peers
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var useModelName string
|
||||
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
useModelName = pm.config.Models[modelID].UseModelName
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -655,8 +855,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
} else {
|
||||
@@ -726,9 +924,46 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||
requestedModel := c.Query("model")
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
||||
return
|
||||
}
|
||||
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
var modelID string
|
||||
|
||||
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
modelID = realModelID
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
modelID = requestedModel
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -743,6 +978,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||
// Returns a pass-through handler if no API keys are configured.
|
||||
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
xApiKey := c.GetHeader("x-api-key")
|
||||
|
||||
var bearerKey string
|
||||
var basicKey string
|
||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
// Basic Auth: base64(username:password), password is the API key
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
basicKey = parts[1] // password is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use first key found: Basic, then Bearer, then x-api-key
|
||||
var providedKey string
|
||||
if basicKey != "" {
|
||||
providedKey = basicKey
|
||||
} else if bearerKey != "" {
|
||||
providedKey = bearerKey
|
||||
} else {
|
||||
providedKey = xApiKey
|
||||
}
|
||||
|
||||
// Validate key
|
||||
valid := false
|
||||
for _, key := range pm.config.RequiredAPIKeys {
|
||||
if providedKey == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Strip auth headers to prevent leakage to upstream
|
||||
c.Request.Header.Del("Authorization")
|
||||
c.Request.Header.Del("x-api-key")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
@@ -756,8 +1052,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"cmd": process.config.Cmd,
|
||||
"proxy": process.config.Proxy,
|
||||
"ttl": process.config.UnloadAfter,
|
||||
"name": process.config.Name,
|
||||
"description": process.config.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -13,22 +14,26 @@ import (
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,9 +84,22 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
Aliases: pm.config.Models[modelID].Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -91,6 +109,7 @@ const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
@@ -150,6 +169,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
@@ -177,11 +208,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -236,3 +275,20 @@ func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, capture)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
logMonitorId := c.Param("logMonitorID")
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
var logger *LogMonitor
|
||||
|
||||
if logMonitorId == "" {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
logger = pm.muxLogger
|
||||
} else if logMonitorId == "proxy" {
|
||||
logger = pm.proxyLogger
|
||||
} else if logMonitorId == "upstream" {
|
||||
logger = pm.upstreamLogger
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||
}
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := config.Config{
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
Peers: map[string]config.PeerConfig{
|
||||
"peer1": {
|
||||
Proxy: "http://peer1:8080",
|
||||
Models: []string{"peer-model-a", "peer-model-b"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(cfg)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
// Check the number of models returned (3 local + 2 peer models)
|
||||
assert.Len(t, response.Data, 5)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"peer-model-a": {},
|
||||
"peer-model-b": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
|
||||
// Peer models should have meta.llamaswap.peerID
|
||||
meta, exists := model["meta"]
|
||||
assert.True(t, exists, "peer model should have meta field")
|
||||
metaMap, ok := meta.(map[string]interface{})
|
||||
assert.True(t, ok, "meta should be a map")
|
||||
llamaswap, exists := metaMap["llamaswap"]
|
||||
assert.True(t, exists, "meta should have llamaswap field")
|
||||
llamaswapMap, ok := llamaswap.(map[string]interface{})
|
||||
assert.True(t, ok, "llamaswap should be a map")
|
||||
peerID, exists := llamaswapMap["peerID"]
|
||||
assert.True(t, exists, "llamaswap should have peerID field")
|
||||
assert.Equal(t, "peer1", peerID)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
@@ -502,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
@@ -650,8 +672,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
@@ -699,6 +726,11 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
|
||||
// Verify extended fields are present
|
||||
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
||||
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
||||
assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -818,6 +850,43 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_AudioVoicesGETHandler(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful GET with model query param", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=model1", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "voice1")
|
||||
})
|
||||
|
||||
t.Run("missing model query param returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/v1/audio/voices?model=nonexistent", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -944,7 +1013,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
Filters: config.Filters{
|
||||
StripParams: "temperature, model, stream",
|
||||
},
|
||||
}
|
||||
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
@@ -975,6 +1046,61 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_FiltersSetParamsByID(t *testing.T) {
|
||||
// no explicit aliases — setParamsByID keys are auto-registered as aliases
|
||||
configStr := strings.Replace(`
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: 'SRPATH --port ${PORT} --silent --respond model1'
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
filters:
|
||||
setParams:
|
||||
reasoning_effort: medium
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`, "SRPATH", simpleResponderPath, -1)
|
||||
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
if !assert.NoError(t, err, "invalid test configuration") {
|
||||
return
|
||||
}
|
||||
|
||||
proxy := New(cfg)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []struct {
|
||||
requestedModel string
|
||||
wantEffort string
|
||||
}{
|
||||
// setParams applies, no setParamsByID match
|
||||
{requestedModel: "model1", wantEffort: "medium"},
|
||||
// setParamsByID overrides setParams
|
||||
{requestedModel: "model1:high", wantEffort: "high"},
|
||||
{requestedModel: "model1:low", wantEffort: "low"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.requestedModel, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf(`{"model":%q}`, tt.requestedModel)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
requestBody, _ := response["request_body"].(string)
|
||||
gotEffort := gjson.Get(requestBody, "reasoning_effort").String()
|
||||
assert.Equal(t, tt.wantEffort, gotEffort, "reasoning_effort mismatch for model %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
@@ -1078,7 +1204,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"author/model": getTestSimpleResponderConfig("author/model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
@@ -1091,6 +1218,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||
"/logs/stream",
|
||||
"/logs/stream/proxy",
|
||||
"/logs/stream/upstream",
|
||||
"/logs/stream/author/model",
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
@@ -1185,3 +1313,428 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
|
||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("both headers with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
req.Header.Set("Authorization", "Bearer valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "invalid-key")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("missing key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
// Basic Auth: base64("anyuser:valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||
// Config without RequiredAPIKeys - auth should be disabled
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
|
||||
// in proxyInferenceHandler for issue #433
|
||||
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
|
||||
t.Run("requests to peer models are proxied", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config with peers but no local model for "peer-model"
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "from-peer")
|
||||
})
|
||||
|
||||
t.Run("local models take precedence over peer models", func(t *testing.T) {
|
||||
// Create a test server to act as the peer - should NOT be called
|
||||
peerCalled := false
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
peerCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config where "shared-model" exists both locally and on peer
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- shared-model
|
||||
models:
|
||||
shared-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-response
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"shared-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "local-response")
|
||||
assert.False(t, peerCalled, "peer should not be called when local model exists")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns error", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer API key is injected into request", func(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
apiKey: secret-peer-key
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
|
||||
})
|
||||
|
||||
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"local-model": getTestSimpleResponderConfig("local-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// peerProxy exists but has no peer models configured
|
||||
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("data: test\n\n"))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_SdApiTxt2ImgRouting(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"sd-model": getTestSimpleResponderConfig("sd-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful txt2img with model", func(t *testing.T) {
|
||||
reqBody := `{"model":"sd-model","prompt":"a cat"}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "sd-model")
|
||||
})
|
||||
|
||||
t.Run("successful img2img with model", func(t *testing.T) {
|
||||
reqBody := `{"model":"sd-model","prompt":"a cat","init_images":[]}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/img2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "sd-model")
|
||||
})
|
||||
|
||||
t.Run("missing model returns 400", func(t *testing.T) {
|
||||
reqBody := `{"prompt":"a cat"}`
|
||||
req := httptest.NewRequest("POST", "/sdapi/v1/txt2img", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing or invalid 'model' key")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_SdApiGetLoras(t *testing.T) {
|
||||
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"sd-model": getTestSimpleResponderConfig("sd-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(conf)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
t.Run("successful GET loras with model query param", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=sd-model", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing model query param returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "missing required 'model' query parameter")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns 400", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/sdapi/v1/loras?model=nonexistent", nil)
|
||||
w := CreateTestResponseRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable handler")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
node_modules
|
||||
.vite
|
||||
@@ -0,0 +1 @@
|
||||
legacy-peer-deps=true
|
||||
@@ -10,8 +10,8 @@
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body >
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"name": "ui-svelte",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"start": "vite",
|
||||
"build": "vite build --emptyOutDir",
|
||||
"preview": "vite preview",
|
||||
"check": "svelte-check --tsconfig ./tsconfig.json",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sveltejs/vite-plugin-svelte": "^7.0.0",
|
||||
"@tailwindcss/vite": "^4.1.8",
|
||||
"@tsconfig/svelte": "^5.0.4",
|
||||
"@types/hast": "^3.0.4",
|
||||
"@types/node": "^25.1.0",
|
||||
"svelte": "^5.46.4",
|
||||
"svelte-check": "^4.1.4",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "~5.8.3",
|
||||
"vite": "^8.0.0",
|
||||
"vite-plugin-compression2": "^2.5.1",
|
||||
"vitest": "^4.1.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.28",
|
||||
"lucide-svelte": "^0.563.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"remark-parse": "^11.0.0",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-spa-router": "^4.0.1",
|
||||
"unified": "^11.0.5",
|
||||
"unist-util-visit": "^5.1.0"
|
||||
}
|
||||
}
|
||||
|
Before Width: | Height: | Size: 5.9 KiB After Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 38 KiB After Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 6.5 KiB After Width: | Height: | Size: 6.5 KiB |
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,58 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import Router from "svelte-spa-router";
|
||||
import Header from "./components/Header.svelte";
|
||||
import LogViewer from "./routes/LogViewer.svelte";
|
||||
import Models from "./routes/Models.svelte";
|
||||
import Activity from "./routes/Activity.svelte";
|
||||
import Playground from "./routes/Playground.svelte";
|
||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||
import { enableAPIEvents } from "./stores/api";
|
||||
import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||
import { currentRoute } from "./stores/route";
|
||||
|
||||
const routes = {
|
||||
"/": PlaygroundStub,
|
||||
"/models": Models,
|
||||
"/logs": LogViewer,
|
||||
"/activity": Activity,
|
||||
"*": PlaygroundStub,
|
||||
};
|
||||
|
||||
function handleRouteLoaded(event: { detail: { route: string | RegExp } }) {
|
||||
const route = event.detail.route;
|
||||
currentRoute.set(typeof route === "string" ? route : "/");
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
document.documentElement.setAttribute("data-theme", $isDarkMode ? "dark" : "light");
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
const icon = $connectionState === "connecting" ? "\u{1F7E1}" : $connectionState === "connected" ? "\u{1F7E2}" : "\u{1F534}";
|
||||
document.title = `${icon} ${$appTitle}`;
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
const cleanupScreenWidth = initScreenWidth();
|
||||
enableAPIEvents(true);
|
||||
|
||||
return () => {
|
||||
cleanupScreenWidth();
|
||||
enableAPIEvents(false);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-screen">
|
||||
<Header />
|
||||
|
||||
<main class="flex-1 overflow-auto p-4">
|
||||
<div class="h-full" class:hidden={$currentRoute !== "/"}>
|
||||
<Playground />
|
||||
</div>
|
||||
<div class="h-full" class:hidden={$currentRoute === "/"}>
|
||||
<Router {routes} on:routeLoaded={handleRouteLoaded} />
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,452 @@
|
||||
<script lang="ts">
|
||||
import type { ReqRespCapture } from "../lib/types";
|
||||
|
||||
interface Props {
|
||||
capture: ReqRespCapture | null;
|
||||
open: boolean;
|
||||
onclose: () => void;
|
||||
}
|
||||
|
||||
let { capture, open, onclose }: Props = $props();
|
||||
|
||||
let dialogEl: HTMLDialogElement | undefined = $state();
|
||||
|
||||
type BodyTab = "raw" | "pretty" | "chat";
|
||||
let reqBodyTab: BodyTab = $state("pretty");
|
||||
let respBodyTab: BodyTab = $state("pretty");
|
||||
let copiedReq = $state(false);
|
||||
let copiedResp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
if (open && dialogEl) {
|
||||
dialogEl.showModal();
|
||||
} else if (!open && dialogEl) {
|
||||
dialogEl.close();
|
||||
}
|
||||
});
|
||||
|
||||
// Reset tabs when capture changes
|
||||
$effect(() => {
|
||||
if (capture) {
|
||||
const reqCt = getContentType(capture.req_headers);
|
||||
const respCt = getContentType(capture.resp_headers);
|
||||
reqBodyTab = reqCt.includes("json") ? "pretty" : "raw";
|
||||
respBodyTab = respCt.includes("text/event-stream")
|
||||
? "chat"
|
||||
: respCt.includes("json")
|
||||
? "pretty"
|
||||
: "raw";
|
||||
}
|
||||
});
|
||||
|
||||
function handleDialogClose() {
|
||||
onclose();
|
||||
}
|
||||
|
||||
function decodeBody(body: string | null | undefined): string {
|
||||
if (!body) return "";
|
||||
try {
|
||||
const binary = atob(body);
|
||||
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
|
||||
return new TextDecoder().decode(bytes);
|
||||
} catch {
|
||||
return body;
|
||||
}
|
||||
}
|
||||
|
||||
function formatJson(str: string): string {
|
||||
try {
|
||||
const parsed = JSON.parse(str);
|
||||
return JSON.stringify(parsed, null, 2);
|
||||
} catch {
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
function getContentType(
|
||||
headers: Record<string, string> | null | undefined,
|
||||
): string {
|
||||
if (!headers) return "";
|
||||
const ct = headers["Content-Type"] || headers["content-type"] || "";
|
||||
return ct.toLowerCase();
|
||||
}
|
||||
|
||||
function isImageContentType(contentType: string): boolean {
|
||||
return contentType.startsWith("image/");
|
||||
}
|
||||
|
||||
function isTextContentType(contentType: string): boolean {
|
||||
return (
|
||||
contentType.startsWith("text/") ||
|
||||
contentType.includes("application/json") ||
|
||||
contentType.includes("application/xml") ||
|
||||
contentType.includes("application/javascript")
|
||||
);
|
||||
}
|
||||
|
||||
function getImageDataUrl(body: string, contentType: string): string {
|
||||
const mimeType = contentType.split(";")[0].trim();
|
||||
return `data:${mimeType};base64,${body}`;
|
||||
}
|
||||
|
||||
interface SSEChat {
|
||||
reasoning: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
function parseSSEChat(text: string): SSEChat {
|
||||
const result: SSEChat = { reasoning: "", content: "" };
|
||||
for (const line of text.split("\n")) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith("data: ")) continue;
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
const delta = parsed.choices?.[0]?.delta;
|
||||
if (delta?.content) result.content += delta.content;
|
||||
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
||||
} catch {
|
||||
// skip unparseable lines
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async function copyToClipboard(text: string, type: "req" | "resp") {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
if (type === "req") {
|
||||
copiedReq = true;
|
||||
setTimeout(() => (copiedReq = false), 1500);
|
||||
} else {
|
||||
copiedResp = true;
|
||||
setTimeout(() => (copiedResp = false), 1500);
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
function getCopyText(): string {
|
||||
if (respBodyTab === "chat") {
|
||||
let text = "";
|
||||
if (sseChat.reasoning) text += sseChat.reasoning + "\n\n";
|
||||
text += sseChat.content;
|
||||
return text;
|
||||
}
|
||||
return displayedResponseBody;
|
||||
}
|
||||
|
||||
// Request body derivations
|
||||
let requestContentType = $derived(
|
||||
capture ? getContentType(capture.req_headers) : "",
|
||||
);
|
||||
let isRequestJson = $derived(requestContentType.includes("json"));
|
||||
|
||||
let requestBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.req_body);
|
||||
});
|
||||
|
||||
let requestBodyPretty = $derived.by(() => {
|
||||
if (!isRequestJson) return requestBodyRaw;
|
||||
return formatJson(requestBodyRaw);
|
||||
});
|
||||
|
||||
let displayedRequestBody = $derived(
|
||||
reqBodyTab === "pretty" ? requestBodyPretty : requestBodyRaw,
|
||||
);
|
||||
|
||||
// Response body derivations
|
||||
let responseContentType = $derived(
|
||||
capture ? getContentType(capture.resp_headers) : "",
|
||||
);
|
||||
let isResponseImage = $derived(isImageContentType(responseContentType));
|
||||
let isResponseText = $derived(isTextContentType(responseContentType));
|
||||
let isResponseJson = $derived(responseContentType.includes("json"));
|
||||
let isSSE = $derived(responseContentType.includes("text/event-stream"));
|
||||
|
||||
let responseBodyRaw = $derived.by(() => {
|
||||
if (!capture) return "";
|
||||
return decodeBody(capture.resp_body);
|
||||
});
|
||||
|
||||
let responseBodyPretty = $derived.by(() => {
|
||||
if (!isResponseJson) return responseBodyRaw;
|
||||
return formatJson(responseBodyRaw);
|
||||
});
|
||||
|
||||
let sseChat = $derived.by(() => {
|
||||
if (!isSSE || !responseBodyRaw)
|
||||
return { reasoning: "", content: "" } as SSEChat;
|
||||
return parseSSEChat(responseBodyRaw);
|
||||
});
|
||||
|
||||
let displayedResponseBody = $derived.by(() => {
|
||||
if (respBodyTab === "pretty") return responseBodyPretty;
|
||||
return responseBodyRaw;
|
||||
});
|
||||
</script>
|
||||
|
||||
<dialog
|
||||
bind:this={dialogEl}
|
||||
onclose={handleDialogClose}
|
||||
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
||||
>
|
||||
{#if capture}
|
||||
<div class="flex flex-col max-h-[90vh]">
|
||||
<div
|
||||
class="flex justify-between items-center p-4 border-b border-card-border"
|
||||
>
|
||||
<h2 class="text-xl font-bold pb-0">Capture #{capture.id + 1}{#if capture.req_path} <span class="text-base font-mono font-normal text-txtsecondary">{capture.req_path}</span>{/if}</h2>
|
||||
<button
|
||||
onclick={() => dialogEl?.close()}
|
||||
class="text-txtsecondary hover:text-txtmain text-2xl leading-none"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="overflow-y-auto flex-1 p-4 space-y-4">
|
||||
<!-- Request Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.req_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Request Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Request Body
|
||||
</summary>
|
||||
{#if requestBodyRaw}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isRequestJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "pretty"}
|
||||
onclick={() => (reqBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={reqBodyTab === "raw"}
|
||||
onclick={() => (reqBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() =>
|
||||
copyToClipboard(displayedRequestBody, "req")}
|
||||
>
|
||||
{#if copiedReq}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedRequestBody}</pre>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono whitespace-pre-wrap break-all"
|
||||
>(empty)</pre
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
|
||||
<!-- Response Headers -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Headers
|
||||
</summary>
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
{#each Object.entries(capture.resp_headers || {}) as [key, value]}
|
||||
<tr class="border-b border-card-border-inner last:border-0">
|
||||
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||
>{key}</td
|
||||
>
|
||||
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Response Body -->
|
||||
<details class="group" open>
|
||||
<summary
|
||||
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||
>
|
||||
Response Body
|
||||
</summary>
|
||||
{#if isResponseImage && capture.resp_body}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 flex justify-center">
|
||||
<img
|
||||
src={getImageDataUrl(capture.resp_body, responseContentType)}
|
||||
alt="Response"
|
||||
class="max-w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSSE || isResponseText}
|
||||
<div class="mt-2 flex items-center justify-between">
|
||||
<div class="flex gap-1">
|
||||
{#if isSSE}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "chat"}
|
||||
onclick={() => (respBodyTab = "chat")}>Chat</button
|
||||
>
|
||||
{/if}
|
||||
{#if isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "pretty"}
|
||||
onclick={() => (respBodyTab = "pretty")}>Pretty</button
|
||||
>
|
||||
{/if}
|
||||
{#if isSSE || isResponseJson}
|
||||
<button
|
||||
class="tab-btn"
|
||||
class:tab-btn-active={respBodyTab === "raw"}
|
||||
onclick={() => (respBodyTab = "raw")}>Raw</button
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
<button
|
||||
class="tab-btn"
|
||||
onclick={() => copyToClipboard(getCopyText(), "resp")}
|
||||
>
|
||||
{#if copiedResp}
|
||||
Copied!
|
||||
{:else}
|
||||
Copy
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
{#if respBodyTab === "chat"}
|
||||
<div class="p-3 text-sm space-y-3">
|
||||
{#if sseChat.reasoning}
|
||||
<div>
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Reasoning
|
||||
</div>
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all text-txtsecondary">{sseChat.reasoning}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if sseChat.content}
|
||||
<div>
|
||||
{#if sseChat.reasoning}
|
||||
<div
|
||||
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||
>
|
||||
Response
|
||||
</div>
|
||||
{/if}
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap break-all">{sseChat.content}</pre>
|
||||
</div>
|
||||
{/if}
|
||||
{#if !sseChat.reasoning && !sseChat.content}
|
||||
<pre class="font-mono">(empty)</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else}
|
||||
<pre
|
||||
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedResponseBody || "(empty)"}</pre>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if responseBodyRaw}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<div class="p-3 text-sm text-txtsecondary italic">
|
||||
(binary data - {responseContentType || "unknown content type"})
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||
>
|
||||
<pre class="p-3 text-sm font-mono">(empty)</pre>
|
||||
</div>
|
||||
{/if}
|
||||
</details>
|
||||
</div>
|
||||
|
||||
<div class="p-4 border-t border-card-border flex justify-end">
|
||||
<button onclick={() => dialogEl?.close()} class="btn"> Close </button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</dialog>
|
||||
|
||||
<style>
|
||||
.tab-btn {
|
||||
padding: 2px 10px;
|
||||
font-size: 0.75rem;
|
||||
border-radius: 4px;
|
||||
color: var(--color-txtsecondary);
|
||||
cursor: pointer;
|
||||
border: 1px solid transparent;
|
||||
background: transparent;
|
||||
transition: all 0.15s;
|
||||
}
|
||||
.tab-btn:hover {
|
||||
color: var(--color-txtmain);
|
||||
background: var(--color-secondary);
|
||||
}
|
||||
.tab-btn-active {
|
||||
color: var(--color-primary);
|
||||
background: color-mix(in srgb, var(--color-primary) 12%, transparent);
|
||||
border-color: color-mix(in srgb, var(--color-primary) 25%, transparent);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,24 @@
|
||||
<script lang="ts">
|
||||
import { connectionState } from "../stores/theme";
|
||||
import { versionInfo } from "../stores/api";
|
||||
|
||||
let eventStatusColor = $derived.by(() => {
|
||||
switch ($connectionState) {
|
||||
case "connected":
|
||||
return "bg-emerald-500";
|
||||
case "connecting":
|
||||
return "bg-amber-500";
|
||||
case "disconnected":
|
||||
default:
|
||||
return "bg-red-500";
|
||||
}
|
||||
});
|
||||
|
||||
let tooltipText = $derived(
|
||||
`Event Stream: ${$connectionState ?? "unknown"}\nAPI Version: ${$versionInfo?.version ?? "unknown"}\nCommit Hash: ${$versionInfo?.commit?.substring(0, 7) ?? "unknown"}\nBuild Date: ${$versionInfo?.build_date ?? "unknown"}`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="flex items-center" title={tooltipText}>
|
||||
<span class="inline-block w-3 h-3 rounded-full {eventStatusColor} mr-2"></span>
|
||||
</div>
|
||||
@@ -0,0 +1,120 @@
|
||||
<script lang="ts">
|
||||
import { link } from "svelte-spa-router";
|
||||
import { screenWidth, toggleTheme, isDarkMode, appTitle, isNarrow } from "../stores/theme";
|
||||
import { currentRoute } from "../stores/route";
|
||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||
|
||||
function handleTitleChange(newTitle: string): void {
|
||||
const sanitized = newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap";
|
||||
appTitle.set(sanitized);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
target.blur();
|
||||
}
|
||||
}
|
||||
|
||||
function handleBlur(e: FocusEvent): void {
|
||||
const target = e.currentTarget as HTMLElement;
|
||||
handleTitleChange(target.textContent || "(set title)");
|
||||
}
|
||||
|
||||
function isActive(path: string, current: string): boolean {
|
||||
return path === "/" ? current === "/" : current.startsWith(path);
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<header
|
||||
class="flex items-center justify-between bg-surface border-b border-border px-4 {$isNarrow
|
||||
? 'py-1 h-[60px]'
|
||||
: 'p-2 h-[75px]'}"
|
||||
>
|
||||
{#if $screenWidth !== "xs" && $screenWidth !== "sm"}
|
||||
<h1
|
||||
contenteditable="true"
|
||||
class="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||
onblur={handleBlur}
|
||||
onkeydown={handleKeyDown}
|
||||
>
|
||||
{$appTitle}
|
||||
</h1>
|
||||
{/if}
|
||||
|
||||
<menu class="flex items-center gap-4 overflow-x-auto">
|
||||
<a
|
||||
href="/"
|
||||
use:link
|
||||
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
||||
>
|
||||
Playground
|
||||
</a>
|
||||
<a
|
||||
href="/models"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/models", $currentRoute)}
|
||||
>
|
||||
Models
|
||||
</a>
|
||||
<a
|
||||
href="/activity"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/activity", $currentRoute)}
|
||||
>
|
||||
Activity
|
||||
</a>
|
||||
<a
|
||||
href="/logs"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/logs", $currentRoute)}
|
||||
>
|
||||
Logs
|
||||
</a>
|
||||
<button onclick={toggleTheme} title="Toggle theme">
|
||||
{#if $isDarkMode}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M9.528 1.718a.75.75 0 0 1 .162.819A8.97 8.97 0 0 0 9 6a9 9 0 0 0 9 9 8.97 8.97 0 0 0 3.463-.69.75.75 0 0 1 .981.98 10.503 10.503 0 0 1-9.694 6.46c-5.799 0-10.5-4.7-10.5-10.5 0-4.368 2.667-8.112 6.46-9.694a.75.75 0 0 1 .818.162Z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path
|
||||
d="M12 2.25a.75.75 0 0 1 .75.75v2.25a.75.75 0 0 1-1.5 0V3a.75.75 0 0 1 .75-.75ZM7.5 12a4.5 4.5 0 1 1 9 0 4.5 4.5 0 0 1-9 0ZM18.894 6.166a.75.75 0 0 0-1.06-1.06l-1.591 1.59a.75.75 0 1 0 1.06 1.061l1.591-1.59ZM21.75 12a.75.75 0 0 1-.75.75h-2.25a.75.75 0 0 1 0-1.5H21a.75.75 0 0 1 .75.75ZM17.834 18.894a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 1 0-1.061 1.06l1.591 1.591ZM12 18a.75.75 0 0 1 .75.75V21a.75.75 0 0 1-1.5 0v-2.25A.75.75 0 0 1 12 18ZM7.758 17.303a.75.75 0 0 0-1.061-1.06l-1.591 1.59a.75.75 0 0 0 1.06 1.061l1.591-1.59ZM6 12a.75.75 0 0 1-.75.75H3a.75.75 0 0 1 0-1.5h2.25A.75.75 0 0 1 6 12ZM6.697 7.757a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 0 0-1.061 1.06l1.59 1.591Z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<ConnectionStatus />
|
||||
</menu>
|
||||
</header>
|
||||
|
||||
<style>
|
||||
.activity-link {
|
||||
background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7, #8b5cf6, #6366f1);
|
||||
background-size: 200% 100%;
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
animation: gradient-shift 2s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes gradient-shift {
|
||||
0% {
|
||||
background-position: 0% 50%;
|
||||
}
|
||||
100% {
|
||||
background-position: 200% 50%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,139 @@
|
||||
<script lang="ts">
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
title: string;
|
||||
logData: string;
|
||||
}
|
||||
|
||||
let { id, title, logData }: Props = $props();
|
||||
|
||||
let filterRegex = $state("");
|
||||
|
||||
// Create persistent stores for this panel (id is intentionally captured at init time)
|
||||
// svelte-ignore state_referenced_locally
|
||||
const fontSizeStore = persistentStore<"xxs" | "xs" | "small" | "normal">(`logPanel-${id}-fontSize`, "normal");
|
||||
// svelte-ignore state_referenced_locally
|
||||
const wrapTextStore = persistentStore<boolean>(`logPanel-${id}-wrapText`, false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
const showFilterStore = persistentStore<boolean>(`logPanel-${id}-showFilter`, false);
|
||||
|
||||
let textWrapClass = $derived($wrapTextStore ? "whitespace-pre-wrap" : "whitespace-pre");
|
||||
|
||||
function toggleFontSize(): void {
|
||||
fontSizeStore.update((prev) => {
|
||||
switch (prev) {
|
||||
case "xxs": return "xs";
|
||||
case "xs": return "small";
|
||||
case "small": return "normal";
|
||||
case "normal": return "xxs";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function toggleWrapText(): void {
|
||||
wrapTextStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function toggleFilter(): void {
|
||||
if ($showFilterStore) {
|
||||
showFilterStore.set(false);
|
||||
filterRegex = "";
|
||||
} else {
|
||||
showFilterStore.set(true);
|
||||
}
|
||||
}
|
||||
|
||||
let fontSizeClass = $derived.by(() => {
|
||||
switch ($fontSizeStore) {
|
||||
case "xxs": return "text-[0.5rem]";
|
||||
case "xs": return "text-[0.75rem]";
|
||||
case "small": return "text-[0.875rem]";
|
||||
case "normal": return "text-base";
|
||||
}
|
||||
});
|
||||
|
||||
let filteredLogs = $derived.by(() => {
|
||||
if (!filterRegex) return logData;
|
||||
try {
|
||||
const regex = new RegExp(filterRegex, "i");
|
||||
return logData.split("\n").filter((line) => regex.test(line)).join("\n");
|
||||
} catch {
|
||||
return logData;
|
||||
}
|
||||
});
|
||||
|
||||
let preElement: HTMLPreElement;
|
||||
let userScrolledUp = $state(false);
|
||||
|
||||
function handleScroll() {
|
||||
if (!preElement) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = preElement;
|
||||
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||
}
|
||||
|
||||
// Auto scroll to bottom when logs change, unless user has scrolled up
|
||||
$effect(() => {
|
||||
if (preElement && filteredLogs && !userScrolledUp) {
|
||||
preElement.scrollTop = preElement.scrollHeight;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full w-full p-1">
|
||||
<div class="p-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="m-0 text-lg p-0">{title}</h3>
|
||||
|
||||
<div class="flex gap-2 items-center">
|
||||
<button class="btn border-0" onclick={toggleFontSize} title="Change font size">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path d="M2 4v3h5v12h3V7h5V4H2zm19 5h-9v3h3v7h3v-7h3V9z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleWrapText} title="Toggle text wrap">
|
||||
{#if $wrapTextStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h10.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
<button class="btn border-0" onclick={toggleFilter} title="Toggle filter">
|
||||
{#if $showFilterStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||
<path fill-rule="evenodd" d="M10.5 3.75a6.75 6.75 0 1 0 0 13.5 6.75 6.75 0 0 0 0-13.5ZM2.25 10.5a8.25 8.25 0 1 1 14.59 5.28l4.69 4.69a.75.75 0 1 1-1.06 1.06l-4.69-4.69A8.25 8.25 0 0 1 2.25 10.5Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if $showFilterStore}
|
||||
<div class="mt-2 flex gap-2 items-center w-full">
|
||||
<input
|
||||
type="text"
|
||||
class="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||
placeholder="Filter logs (regex)..."
|
||||
bind:value={filterRegex}
|
||||
/>
|
||||
<button class="pl-2" onclick={() => (filterRegex = "")} aria-label="Clear filter">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm-1.72 6.97a.75.75 0 1 0-1.06 1.06L10.94 12l-1.72 1.72a.75.75 0 1 0 1.06 1.06L12 13.06l1.72 1.72a.75.75 0 1 0 1.06-1.06L13.06 12l1.72-1.72a.75.75 0 1 0-1.06-1.06L12 10.94l-1.72-1.72Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||
<pre bind:this={preElement} onscroll={handleScroll} class="{textWrapClass} {fontSizeClass} h-full overflow-auto p-4">{filteredLogs}</pre>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,211 @@
|
||||
<script lang="ts">
|
||||
import { models, loadModel, unloadAllModels, unloadSingleModel } from "../stores/api";
|
||||
import { isNarrow } from "../stores/theme";
|
||||
import { persistentStore } from "../stores/persistent";
|
||||
import type { Model } from "../lib/types";
|
||||
|
||||
let isUnloading = $state(false);
|
||||
let menuOpen = $state(false);
|
||||
|
||||
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||
|
||||
let filteredModels = $derived.by(() => {
|
||||
const filtered = $models.filter((model) => $showUnlistedStore || !model.unlisted);
|
||||
const peerModels = filtered.filter((m) => m.peerID);
|
||||
|
||||
// Group peer models by peerID
|
||||
const grouped = peerModels.reduce(
|
||||
(acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
if (!acc[peerId]) acc[peerId] = [];
|
||||
acc[peerId].push(model);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, Model[]>
|
||||
);
|
||||
|
||||
return {
|
||||
regularModels: filtered.filter((m) => !m.peerID),
|
||||
peerModelsByPeerId: grouped,
|
||||
};
|
||||
});
|
||||
|
||||
async function handleUnloadAllModels(): Promise<void> {
|
||||
isUnloading = true;
|
||||
try {
|
||||
await unloadAllModels();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
} finally {
|
||||
setTimeout(() => (isUnloading = false), 1000);
|
||||
}
|
||||
}
|
||||
|
||||
function toggleIdorName(): void {
|
||||
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||
}
|
||||
|
||||
function toggleShowUnlisted(): void {
|
||||
showUnlistedStore.update((prev) => !prev);
|
||||
}
|
||||
|
||||
function getModelDisplay(model: Model): string {
|
||||
return $showIdorNameStore === "id" ? model.id : (model.name || model.id);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="card h-full flex flex-col">
|
||||
<div class="shrink-0">
|
||||
<div class="flex justify-between items-baseline">
|
||||
<h2 class={$isNarrow ? "text-xl" : ""}>Models</h2>
|
||||
{#if $isNarrow}
|
||||
<div class="relative">
|
||||
<button class="btn text-base flex items-center gap-2 py-1" onclick={() => (menuOpen = !menuOpen)} aria-label="Toggle menu">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if menuOpen}
|
||||
<div class="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleIdorName(); menuOpen = false; }}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "Show Name" : "Show ID"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { toggleShowUnlisted(); menuOpen = false; }}
|
||||
>
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{/if}
|
||||
{$showUnlistedStore ? "Hide Unlisted" : "Show Unlisted"}
|
||||
</button>
|
||||
<button
|
||||
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||
onclick={() => { handleUnloadAllModels(); menuOpen = false; }}
|
||||
disabled={isUnloading}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{#if !$isNarrow}
|
||||
<div class="flex justify-between">
|
||||
<div class="flex gap-2">
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleIdorName} style="line-height: 1.2">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{$showIdorNameStore === "id" ? "ID" : "Name"}
|
||||
</button>
|
||||
|
||||
<button class="btn text-base flex items-center gap-2" onclick={toggleShowUnlisted} style="line-height: 1.2">
|
||||
{#if $showUnlistedStore}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||
</svg>
|
||||
{/if}
|
||||
unlisted
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn text-base flex items-center gap-2" onclick={handleUnloadAllModels} disabled={isUnloading}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{isUnloading ? "Unloading..." : "Unload All"}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th>{$showIdorNameStore === "id" ? "Model ID" : "Name"}</th>
|
||||
<th></th>
|
||||
<th>State</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each filteredModels.regularModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class={model.unlisted ? "text-txtsecondary" : ""}>
|
||||
<a href="/upstream/{model.id}/" class="font-semibold" target="_blank">
|
||||
{getModelDisplay(model)}
|
||||
</a>
|
||||
{#if model.description}
|
||||
<p class={model.unlisted ? "text-opacity-70" : ""}><em>{model.description}</em></p>
|
||||
{/if}
|
||||
{#if model.aliases && model.aliases.length > 0}
|
||||
<p class="text-xs text-txtsecondary">Aliases: {model.aliases.join(", ")}</p>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-12">
|
||||
{#if model.state === "stopped"}
|
||||
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
||||
{:else}
|
||||
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="w-20">
|
||||
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{#if Object.keys(filteredModels.peerModelsByPeerId).length > 0}
|
||||
<h3 class="mt-8 mb-2">Peer Models</h3>
|
||||
{#each Object.entries(filteredModels.peerModelsByPeerId).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||
<div class="mb-4">
|
||||
<table class="w-full">
|
||||
<thead class="sticky top-0 bg-card z-10">
|
||||
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th class="font-semibold">{peerId}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each peerModels as model (model.id)}
|
||||
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td class="pl-8 {model.unlisted ? 'text-txtsecondary' : ''}">
|
||||
<span>{model.id}</span>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,152 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet } from "svelte";
|
||||
import { onMount } from "svelte";
|
||||
|
||||
interface Props {
|
||||
direction: "horizontal" | "vertical";
|
||||
storageKey: string;
|
||||
leftPanel: Snippet;
|
||||
rightPanel: Snippet;
|
||||
defaultSize?: number;
|
||||
minSize?: number;
|
||||
}
|
||||
|
||||
let { direction, storageKey, leftPanel, rightPanel, defaultSize = 50, minSize = 5 }: Props = $props();
|
||||
|
||||
let containerRef: HTMLDivElement;
|
||||
let isDragging = $state(false);
|
||||
// svelte-ignore state_referenced_locally
|
||||
let leftSize = $state(defaultSize);
|
||||
|
||||
// Load saved size from localStorage
|
||||
onMount(() => {
|
||||
const saved = localStorage.getItem(`panel-size-${storageKey}`);
|
||||
if (saved) {
|
||||
const parsed = parseFloat(saved);
|
||||
if (!isNaN(parsed) && parsed >= minSize && parsed <= 100 - minSize) {
|
||||
leftSize = parsed;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function saveSize(): void {
|
||||
localStorage.setItem(`panel-size-${storageKey}`, String(leftSize));
|
||||
}
|
||||
|
||||
function handleMouseDown(e: MouseEvent): void {
|
||||
e.preventDefault();
|
||||
isDragging = true;
|
||||
document.addEventListener("mousemove", handleMouseMove);
|
||||
document.addEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchStart(_e: TouchEvent): void {
|
||||
isDragging = true;
|
||||
document.addEventListener("touchmove", handleTouchMove);
|
||||
document.addEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleMouseMove(e: MouseEvent): void {
|
||||
if (!isDragging || !containerRef) return;
|
||||
updateSize(e.clientX, e.clientY);
|
||||
}
|
||||
|
||||
function handleTouchMove(e: TouchEvent): void {
|
||||
if (!isDragging || !containerRef || e.touches.length === 0) return;
|
||||
updateSize(e.touches[0].clientX, e.touches[0].clientY);
|
||||
}
|
||||
|
||||
function updateSize(clientX: number, clientY: number): void {
|
||||
const rect = containerRef.getBoundingClientRect();
|
||||
|
||||
let newSize: number;
|
||||
if (direction === "horizontal") {
|
||||
newSize = ((clientX - rect.left) / rect.width) * 100;
|
||||
} else {
|
||||
newSize = ((clientY - rect.top) / rect.height) * 100;
|
||||
}
|
||||
|
||||
// Clamp size
|
||||
newSize = Math.max(minSize, Math.min(100 - minSize, newSize));
|
||||
leftSize = newSize;
|
||||
}
|
||||
|
||||
function handleMouseUp(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("mousemove", handleMouseMove);
|
||||
document.removeEventListener("mouseup", handleMouseUp);
|
||||
}
|
||||
|
||||
function handleTouchEnd(): void {
|
||||
isDragging = false;
|
||||
saveSize();
|
||||
document.removeEventListener("touchmove", handleTouchMove);
|
||||
document.removeEventListener("touchend", handleTouchEnd);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent): void {
|
||||
const step = 2; // 2% increment for keyboard navigation
|
||||
const key = e.key;
|
||||
|
||||
if (direction === "horizontal" && (key === "ArrowLeft" || key === "ArrowRight")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowLeft" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
} else if (direction === "vertical" && (key === "ArrowUp" || key === "ArrowDown")) {
|
||||
e.preventDefault();
|
||||
const delta = key === "ArrowUp" ? -step : step;
|
||||
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||
leftSize = newSize;
|
||||
saveSize();
|
||||
}
|
||||
}
|
||||
|
||||
let containerClass = $derived(direction === "horizontal" ? "flex-row" : "flex-col");
|
||||
|
||||
let handleClass = $derived(
|
||||
direction === "horizontal"
|
||||
? "w-2 h-full cursor-col-resize"
|
||||
: "w-full h-2 cursor-row-resize"
|
||||
);
|
||||
|
||||
let leftStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
|
||||
let rightStyle = $derived(
|
||||
direction === "horizontal"
|
||||
? `width: ${100 - leftSize}%; min-width: ${minSize}%`
|
||||
: `height: ${100 - leftSize}%; min-height: ${minSize}%`
|
||||
);
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="flex {containerClass} h-full w-full gap-2">
|
||||
<div style={leftStyle} class="overflow-hidden">
|
||||
{@render leftPanel()}
|
||||
</div>
|
||||
|
||||
<!-- svelte-ignore a11y_no_noninteractive_tabindex -->
|
||||
<!-- svelte-ignore a11y_no_noninteractive_element_interactions -->
|
||||
<div
|
||||
role="separator"
|
||||
tabindex="0"
|
||||
class="{handleClass} bg-primary hover:bg-success transition-colors rounded flex-shrink-0"
|
||||
onmousedown={handleMouseDown}
|
||||
ontouchstart={handleTouchStart}
|
||||
onkeydown={handleKeyDown}
|
||||
aria-label="Resize panels"
|
||||
aria-orientation={direction}
|
||||
aria-valuenow={Math.round(leftSize)}
|
||||
aria-valuemin={minSize}
|
||||
aria-valuemax={100 - minSize}
|
||||
></div>
|
||||
|
||||
<div style={rightStyle} class="overflow-hidden">
|
||||
{@render rightPanel()}
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,167 @@
|
||||
<script lang="ts">
|
||||
import { inFlightRequests, metrics } from "../stores/api";
|
||||
import TokenHistogram from "./TokenHistogram.svelte";
|
||||
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
let stats = $derived.by(() => {
|
||||
const totalRequests = $metrics.length;
|
||||
if (totalRequests === 0) {
|
||||
return {
|
||||
totalRequests: 0,
|
||||
totalInputTokens: 0,
|
||||
totalOutputTokens: 0,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||
|
||||
// Calculate token statistics using output_tokens and duration_ms
|
||||
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
||||
if (validMetrics.length === 0) {
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
||||
histogramData: null,
|
||||
};
|
||||
}
|
||||
|
||||
// Calculate tokens/second for each valid metric
|
||||
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
||||
|
||||
// Sort for percentile calculation
|
||||
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
||||
|
||||
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
||||
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
||||
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
||||
|
||||
// Create histogram data
|
||||
const min = Math.min(...tokensPerSecond);
|
||||
const max = Math.max(...tokensPerSecond);
|
||||
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
|
||||
const binSize = (max - min) / binCount;
|
||||
|
||||
const bins = Array(binCount).fill(0);
|
||||
tokensPerSecond.forEach((value) => {
|
||||
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
||||
bins[binIndex]++;
|
||||
});
|
||||
|
||||
const histogramData: HistogramData = {
|
||||
bins,
|
||||
min,
|
||||
max,
|
||||
binSize,
|
||||
p99,
|
||||
p95,
|
||||
p50,
|
||||
};
|
||||
|
||||
return {
|
||||
totalRequests,
|
||||
totalInputTokens,
|
||||
totalOutputTokens,
|
||||
inFlightRequests: $inFlightRequests,
|
||||
tokenStats: {
|
||||
p99: p99.toFixed(2),
|
||||
p95: p95.toFixed(2),
|
||||
p50: p50.toFixed(2),
|
||||
},
|
||||
histogramData,
|
||||
};
|
||||
});
|
||||
|
||||
const nf = new Intl.NumberFormat();
|
||||
</script>
|
||||
|
||||
<div class="card">
|
||||
<div class="rounded-lg overflow-hidden border border-card-border-inner">
|
||||
<table class="min-w-full divide-y divide-card-border-inner">
|
||||
<thead class="bg-secondary">
|
||||
<tr>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Processed
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Generated
|
||||
</th>
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
||||
Token Stats (tokens/sec)
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
||||
<tbody class="bg-surface divide-y divide-card-border-inner">
|
||||
<tr class="hover:bg-secondary">
|
||||
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
|
||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
||||
<div class="space-y-3">
|
||||
<div class="grid grid-cols-3 gap-2 items-center">
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p50}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p95}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-center">
|
||||
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
||||
{stats.tokenStats.p99}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{#if stats.histogramData}
|
||||
<TokenHistogram data={stats.histogramData} />
|
||||
{/if}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,129 @@
|
||||
<script lang="ts">
|
||||
interface HistogramData {
|
||||
bins: number[];
|
||||
min: number;
|
||||
max: number;
|
||||
binSize: number;
|
||||
p99: number;
|
||||
p95: number;
|
||||
p50: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
data: HistogramData;
|
||||
}
|
||||
|
||||
let { data }: Props = $props();
|
||||
|
||||
const height = 120;
|
||||
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
||||
const viewBoxWidth = 600;
|
||||
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||
const chartHeight = height - padding.top - padding.bottom;
|
||||
|
||||
let maxCount = $derived(Math.max(...data.bins));
|
||||
let barWidth = $derived(chartWidth / data.bins.length);
|
||||
let range = $derived(data.max - data.min);
|
||||
|
||||
function getXPosition(value: number): number {
|
||||
return padding.left + ((value - data.min) / range) * chartWidth;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="mt-2 w-full">
|
||||
<svg viewBox="0 0 {viewBoxWidth} {height}" class="w-full h-auto" preserveAspectRatio="xMidYMid meet">
|
||||
<!-- Y-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={padding.top}
|
||||
x2={padding.left}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- X-axis -->
|
||||
<line
|
||||
x1={padding.left}
|
||||
y1={height - padding.bottom}
|
||||
x2={viewBoxWidth - padding.right}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="1"
|
||||
opacity="0.3"
|
||||
/>
|
||||
|
||||
<!-- Histogram bars -->
|
||||
{#each data.bins as count, i}
|
||||
{@const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0}
|
||||
{@const x = padding.left + i * barWidth}
|
||||
{@const y = height - padding.bottom - barHeight}
|
||||
{@const binStart = data.min + i * data.binSize}
|
||||
{@const binEnd = binStart + data.binSize}
|
||||
<g>
|
||||
<rect
|
||||
{x}
|
||||
{y}
|
||||
width={Math.max(barWidth - 1, 1)}
|
||||
height={barHeight}
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
||||
/>
|
||||
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
||||
</g>
|
||||
{/each}
|
||||
|
||||
<!-- Percentile lines -->
|
||||
<line
|
||||
x1={getXPosition(data.p50)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p50)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-gray-600 dark:text-gray-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p95)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p95)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-orange-500 dark:text-orange-400"
|
||||
/>
|
||||
|
||||
<line
|
||||
x1={getXPosition(data.p99)}
|
||||
y1={padding.top}
|
||||
x2={getXPosition(data.p99)}
|
||||
y2={height - padding.bottom}
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-dasharray="4 2"
|
||||
opacity="0.7"
|
||||
class="text-green-500 dark:text-green-400"
|
||||
/>
|
||||
|
||||
<!-- X-axis labels -->
|
||||
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start">
|
||||
{data.min.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end">
|
||||
{data.max.toFixed(1)}
|
||||
</text>
|
||||
|
||||
<!-- X-axis label -->
|
||||
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
|
||||
Tokens/Second Distribution
|
||||
</text>
|
||||
</svg>
|
||||
</div>
|
||||
@@ -0,0 +1,20 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
content: string;
|
||||
}
|
||||
|
||||
let { content }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative group inline-block">
|
||||
<span class="cursor-help">ⓘ</span>
|
||||
<div
|
||||
class="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||
opacity-0 group-hover:opacity-100 transition-opacity
|
||||
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||
>
|
||||
{content}
|
||||
<div class="absolute bottom-full left-1/2 transform -translate-x-1/2 border-4 border-transparent border-b-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,256 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { transcribeAudio } from "../../lib/audioApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-audio-model", "");
|
||||
|
||||
let selectedFile = $state<File | null>(null);
|
||||
let isTranscribing = $state(false);
|
||||
let transcriptionResult = $state<string | null>(null);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let isDragging = $state(false);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let copied = $state(false);
|
||||
|
||||
const ACCEPTED_FORMATS = ['.mp3', '.wav', '.ogg'];
|
||||
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
let canTranscribe = $derived(selectedFile !== null && $selectedModelStore !== "" && !isTranscribing);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.audioTranscribing.set(isTranscribing);
|
||||
});
|
||||
|
||||
function validateFile(file: File): { valid: boolean; error?: string } {
|
||||
const ext = '.' + file.name.split('.').pop()?.toLowerCase();
|
||||
|
||||
if (!ACCEPTED_FORMATS.includes(ext)) {
|
||||
return { valid: false, error: 'Invalid file type. Accepted: MP3, WAV, OGG' };
|
||||
}
|
||||
|
||||
if (file.size > MAX_FILE_SIZE) {
|
||||
return { valid: false, error: 'File too large. Maximum: 25MB' };
|
||||
}
|
||||
|
||||
return { valid: true };
|
||||
}
|
||||
|
||||
function handleFileSelect(event: Event) {
|
||||
const target = event.target as HTMLInputElement;
|
||||
const file = target.files?.[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = true;
|
||||
}
|
||||
|
||||
function handleDragLeave() {
|
||||
isDragging = false;
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragging = false;
|
||||
|
||||
const file = event.dataTransfer?.files[0];
|
||||
if (file) {
|
||||
const validation = validateFile(file);
|
||||
if (validation.valid) {
|
||||
selectedFile = file;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
} else {
|
||||
error = validation.error || "Invalid file";
|
||||
selectedFile = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function transcribe() {
|
||||
if (!selectedFile || !$selectedModelStore || isTranscribing) return;
|
||||
|
||||
isTranscribing = true;
|
||||
error = null;
|
||||
transcriptionResult = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const response = await transcribeAudio(
|
||||
$selectedModelStore,
|
||||
selectedFile,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
transcriptionResult = response.text;
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isTranscribing = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelTranscription() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearAll() {
|
||||
selectedFile = null;
|
||||
transcriptionResult = null;
|
||||
error = null;
|
||||
if (fileInput) {
|
||||
fileInput.value = '';
|
||||
}
|
||||
}
|
||||
|
||||
function copyToClipboard() {
|
||||
if (transcriptionResult) {
|
||||
navigator.clipboard.writeText(transcriptionResult);
|
||||
copied = true;
|
||||
setTimeout(() => {
|
||||
copied = false;
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
function formatFileSize(bytes: number): string {
|
||||
if (bytes < 1024) return bytes + ' B';
|
||||
if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB';
|
||||
return (bytes / (1024 * 1024)).toFixed(1) + ' MB';
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} />
|
||||
</div>
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to transcribe audio.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- File upload / Result display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isTranscribing}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Transcribing audio...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if transcriptionResult}
|
||||
<div class="w-full h-full flex flex-col p-4">
|
||||
<div class="flex justify-between items-center mb-2">
|
||||
<h3 class="font-medium">Transcription Result</h3>
|
||||
<button
|
||||
class="btn btn-sm"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? 'Copied!' : 'Copy to clipboard'}
|
||||
>
|
||||
{#if copied}
|
||||
<svg class="w-5 h-5 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
|
||||
</svg>
|
||||
{:else}
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"></path>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex-1 overflow-auto p-3 rounded border border-gray-200 dark:border-white/10 bg-background whitespace-pre-wrap">
|
||||
{transcriptionResult}
|
||||
</div>
|
||||
</div>
|
||||
{:else if selectedFile}
|
||||
<div class="text-center text-txtsecondary p-4">
|
||||
<p class="font-medium mb-2">File Selected</p>
|
||||
<p class="text-sm">{selectedFile.name}</p>
|
||||
<p class="text-xs mt-1">{formatFileSize(selectedFile.size)}</p>
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
role="region"
|
||||
aria-label="Audio file drop zone"
|
||||
class="w-full h-full flex items-center justify-center text-center text-txtsecondary p-8 {isDragging ? 'bg-primary/10' : ''}"
|
||||
ondragover={handleDragOver}
|
||||
ondragleave={handleDragLeave}
|
||||
ondrop={handleDrop}
|
||||
>
|
||||
<div>
|
||||
<p class="mb-2">Drag and drop an audio file here</p>
|
||||
<p class="text-sm">or use the Browse button below</p>
|
||||
<p class="text-xs mt-4">Accepted formats: MP3, WAV, OGG (max 25MB)</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- File input and transcribe button -->
|
||||
<div class="shrink-0 flex gap-2">
|
||||
<input
|
||||
type="file"
|
||||
accept=".mp3,.wav,.ogg"
|
||||
class="hidden"
|
||||
onchange={handleFileSelect}
|
||||
bind:this={fileInput}
|
||||
/>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isTranscribing}
|
||||
>
|
||||
Browse Files
|
||||
</button>
|
||||
<div class="flex-1"></div>
|
||||
{#if isTranscribing}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelTranscription}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={transcribe}
|
||||
disabled={!canTranscribe}
|
||||
>
|
||||
Transcribe
|
||||
</button>
|
||||
<button
|
||||
class="btn"
|
||||
onclick={clearAll}
|
||||
disabled={!selectedFile && !transcriptionResult && !error}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,466 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { streamChatCompletion } from "../../lib/chatApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import type { ChatMessage, ContentPart } from "../../lib/types";
|
||||
import ChatMessageComponent from "./ChatMessage.svelte";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-selected-model", "");
|
||||
const systemPromptStore = persistentStore<string>("playground-system-prompt", "");
|
||||
const temperatureStore = persistentStore<number>("playground-temperature", 0.7);
|
||||
|
||||
function loadMessages(): ChatMessage[] {
|
||||
try {
|
||||
const saved = localStorage.getItem("playground-messages");
|
||||
return saved ? JSON.parse(saved) : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
let messages = $state<ChatMessage[]>(loadMessages());
|
||||
let userInput = $state("");
|
||||
let isStreaming = $state(false);
|
||||
let isReasoning = $state(false);
|
||||
let reasoningStartTime = $state<number>(0);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let messagesContainer: HTMLDivElement | undefined = $state();
|
||||
let showSettings = $state(false);
|
||||
let attachedImages = $state<string[]>([]);
|
||||
let fileInput = $state<HTMLInputElement | null>(null);
|
||||
let imageError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
let userScrolledUp = $state(false);
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.chatStreaming.set(isStreaming);
|
||||
});
|
||||
|
||||
function handleMessagesScroll() {
|
||||
if (!messagesContainer) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = messagesContainer;
|
||||
// Consider "at bottom" if within 40px of the bottom
|
||||
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||
}
|
||||
|
||||
// Auto-scroll when messages change — skip if user scrolled up
|
||||
$effect(() => {
|
||||
if (messages.length > 0 && messagesContainer && !userScrolledUp) {
|
||||
messagesContainer.scrollTo({
|
||||
top: messagesContainer.scrollHeight,
|
||||
behavior: isStreaming ? "instant" : "smooth",
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Persist messages to localStorage (throttled to once per 2s)
|
||||
let lastSaveTime = 0;
|
||||
$effect(() => {
|
||||
const json = JSON.stringify(messages);
|
||||
const elapsed = Date.now() - lastSaveTime;
|
||||
const save = () => {
|
||||
try { localStorage.setItem("playground-messages", json); } catch {}
|
||||
lastSaveTime = Date.now();
|
||||
};
|
||||
if (elapsed >= 2000) {
|
||||
save();
|
||||
return;
|
||||
}
|
||||
const timer = setTimeout(save, 2000 - elapsed);
|
||||
return () => clearTimeout(timer);
|
||||
});
|
||||
|
||||
async function sendMessage() {
|
||||
const trimmedInput = userInput.trim();
|
||||
if ((!trimmedInput && attachedImages.length === 0) || !$selectedModelStore || isStreaming) return;
|
||||
|
||||
userScrolledUp = false;
|
||||
|
||||
// Build message content (multimodal if images attached)
|
||||
let content: string | ContentPart[];
|
||||
if (attachedImages.length > 0) {
|
||||
const parts: ContentPart[] = [];
|
||||
if (trimmedInput) {
|
||||
parts.push({ type: "text", text: trimmedInput });
|
||||
}
|
||||
for (const url of attachedImages) {
|
||||
parts.push({ type: "image_url", image_url: { url } });
|
||||
}
|
||||
content = parts;
|
||||
} else {
|
||||
content = trimmedInput;
|
||||
}
|
||||
|
||||
// Add user message
|
||||
messages = [...messages, { role: "user", content }];
|
||||
userInput = "";
|
||||
attachedImages = [];
|
||||
imageError = null;
|
||||
|
||||
// Generate response from the new user message
|
||||
await regenerateFromIndex(messages.length - 1);
|
||||
}
|
||||
|
||||
function cancelStreaming() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function newChat() {
|
||||
if (isStreaming) {
|
||||
cancelStreaming();
|
||||
}
|
||||
messages = [];
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
}
|
||||
|
||||
async function regenerateFromIndex(idx: number) {
|
||||
// Remove all messages after the edited user message
|
||||
messages = messages.slice(0, idx + 1);
|
||||
|
||||
// Add empty assistant message for the new response
|
||||
messages = [...messages, { role: "assistant", content: "" }];
|
||||
|
||||
isStreaming = true;
|
||||
isReasoning = false;
|
||||
reasoningStartTime = 0;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
// Build messages array with optional system prompt
|
||||
const apiMessages: ChatMessage[] = [];
|
||||
if ($systemPromptStore.trim()) {
|
||||
apiMessages.push({ role: "system", content: $systemPromptStore.trim() });
|
||||
}
|
||||
apiMessages.push(...messages.slice(0, -1)); // Add all messages except the empty assistant one
|
||||
|
||||
const stream = streamChatCompletion(
|
||||
$selectedModelStore,
|
||||
apiMessages,
|
||||
abortController.signal,
|
||||
{ temperature: $temperatureStore }
|
||||
);
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.done) break;
|
||||
|
||||
// Handle reasoning content
|
||||
if (chunk.reasoning_content) {
|
||||
// Start timing on first reasoning content
|
||||
if (!isReasoning) {
|
||||
isReasoning = true;
|
||||
reasoningStartTime = Date.now();
|
||||
}
|
||||
|
||||
// Update the last message with reasoning content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoning_content: (msg.reasoning_content || "") + chunk.reasoning_content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Handle regular content - end reasoning phase when we get content
|
||||
if (chunk.content) {
|
||||
if (isReasoning) {
|
||||
// Calculate reasoning time
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
isReasoning = false;
|
||||
|
||||
// Update message with reasoning time
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
|
||||
// Update the last message (assistant) with new content
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + chunk.content }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
// User cancelled, keep partial response
|
||||
// If we were still reasoning, record the time
|
||||
if (isReasoning && reasoningStartTime > 0) {
|
||||
const reasoningTimeMs = Date.now() - reasoningStartTime;
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, reasoningTimeMs }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Show error in the assistant message
|
||||
const errorMessage = error instanceof Error ? error.message : "An error occurred";
|
||||
messages = messages.map((msg, i) =>
|
||||
i === messages.length - 1
|
||||
? { ...msg, content: msg.content + `\n\n**Error:** ${errorMessage}` }
|
||||
: msg
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
isStreaming = false;
|
||||
isReasoning = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
async function editMessage(idx: number, newContent: string) {
|
||||
if (isStreaming || !$selectedModelStore) return;
|
||||
|
||||
// Update the user message at the specified index
|
||||
messages = messages.map((msg, i) =>
|
||||
i === idx ? { ...msg, content: newContent } : msg
|
||||
);
|
||||
|
||||
// Trigger a new chat request with the updated messages
|
||||
await regenerateFromIndex(idx);
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
sendMessage();
|
||||
}
|
||||
}
|
||||
|
||||
const ACCEPTED_IMAGE_FORMATS = ["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||
const MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB
|
||||
const MAX_IMAGES_PER_MESSAGE = 5;
|
||||
|
||||
function validateImageFile(file: File): string | null {
|
||||
if (!ACCEPTED_IMAGE_FORMATS.includes(file.type)) {
|
||||
return `Invalid file type: ${file.type}. Accepted formats: JPG, PNG, GIF, WEBP`;
|
||||
}
|
||||
if (file.size > MAX_IMAGE_SIZE) {
|
||||
return `File too large: ${(file.size / 1024 / 1024).toFixed(1)}MB. Maximum size: 20MB`;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function fileToDataUrl(file: File): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.onerror = () => reject(new Error("Failed to read file"));
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function processImageFiles(files: File[]): Promise<void> {
|
||||
imageError = null;
|
||||
|
||||
if (attachedImages.length + files.length > MAX_IMAGES_PER_MESSAGE) {
|
||||
imageError = `Maximum ${MAX_IMAGES_PER_MESSAGE} images per message`;
|
||||
return;
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
const error = validateImageFile(file);
|
||||
if (error) {
|
||||
imageError = error;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const dataUrls = await Promise.all(files.map(fileToDataUrl));
|
||||
attachedImages = [...attachedImages, ...dataUrls];
|
||||
} catch (error) {
|
||||
imageError = error instanceof Error ? error.message : "Failed to process images";
|
||||
}
|
||||
}
|
||||
|
||||
function handleImageSelect(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
if (input.files && input.files.length > 0) {
|
||||
processImageFiles(Array.from(input.files));
|
||||
}
|
||||
// Reset the input so the same file can be selected again
|
||||
input.value = "";
|
||||
}
|
||||
|
||||
function removeImage(idx: number) {
|
||||
attachedImages = attachedImages.filter((_, i) => i !== idx);
|
||||
imageError = null;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and controls -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a model..." disabled={isStreaming} />
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => (showSettings = !showSettings)}
|
||||
title="Settings"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M8.34 1.804A1 1 0 0 1 9.32 1h1.36a1 1 0 0 1 .98.804l.295 1.473c.497.144.971.342 1.416.587l1.25-.834a1 1 0 0 1 1.262.125l.962.962a1 1 0 0 1 .125 1.262l-.834 1.25c.245.445.443.919.587 1.416l1.473.295a1 1 0 0 1 .804.98v1.36a1 1 0 0 1-.804.98l-1.473.295a6.95 6.95 0 0 1-.587 1.416l.834 1.25a1 1 0 0 1-.125 1.262l-.962.962a1 1 0 0 1-1.262.125l-1.25-.834a6.953 6.953 0 0 1-1.416.587l-.295 1.473a1 1 0 0 1-.98.804H9.32a1 1 0 0 1-.98-.804l-.295-1.473a6.957 6.957 0 0 1-1.416-.587l-1.25.834a1 1 0 0 1-1.262-.125l-.962-.962a1 1 0 0 1-.125-1.262l.834-1.25a6.957 6.957 0 0 1-.587-1.416l-1.473-.295A1 1 0 0 1 1 10.68V9.32a1 1 0 0 1 .804-.98l1.473-.295c.144-.497.342-.971.587-1.416l-.834-1.25a1 1 0 0 1 .125-1.262l.962-.962A1 1 0 0 1 5.38 3.03l1.25.834a6.957 6.957 0 0 1 1.416-.587l.294-1.473ZM13 10a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button class="btn" onclick={newChat} disabled={messages.length === 0 && !isStreaming}>
|
||||
New Chat
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Settings panel -->
|
||||
{#if showSettings}
|
||||
<div class="shrink-0 mb-4 p-4 bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
<div class="mb-4">
|
||||
<label class="block text-sm font-medium mb-1" for="system-prompt">System Prompt</label>
|
||||
<textarea
|
||||
id="system-prompt"
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder="You are a helpful assistant..."
|
||||
rows="3"
|
||||
bind:value={$systemPromptStore}
|
||||
disabled={isStreaming}
|
||||
></textarea>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm font-medium mb-1" for="temperature">
|
||||
Temperature: {$temperatureStore.toFixed(2)}
|
||||
</label>
|
||||
<input
|
||||
id="temperature"
|
||||
type="range"
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.05"
|
||||
class="w-full"
|
||||
bind:value={$temperatureStore}
|
||||
disabled={isStreaming}
|
||||
/>
|
||||
<div class="flex justify-between text-xs text-txtsecondary mt-1">
|
||||
<span>Precise (0)</span>
|
||||
<span>Creative (2)</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to start chatting.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Messages area -->
|
||||
<div
|
||||
class="flex-1 overflow-y-auto mb-4 px-2"
|
||||
bind:this={messagesContainer}
|
||||
onscroll={handleMessagesScroll}
|
||||
>
|
||||
{#if messages.length === 0}
|
||||
<div class="h-full flex items-center justify-center text-txtsecondary">
|
||||
<p>Start a conversation by typing a message below.</p>
|
||||
</div>
|
||||
{:else}
|
||||
{#each messages as message, idx (idx)}
|
||||
<ChatMessageComponent
|
||||
role={message.role}
|
||||
content={message.content}
|
||||
reasoning_content={message.reasoning_content}
|
||||
reasoningTimeMs={message.reasoningTimeMs}
|
||||
isStreaming={isStreaming && idx === messages.length - 1 && message.role === "assistant"}
|
||||
isReasoning={isReasoning && idx === messages.length - 1 && message.role === "assistant"}
|
||||
onEdit={message.role === "user" ? (newContent) => editMessage(idx, newContent) : undefined}
|
||||
onRegenerate={message.role === "assistant" && idx > 0 && messages[idx - 1].role === "user"
|
||||
? () => regenerateFromIndex(idx - 1)
|
||||
: undefined}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Input area -->
|
||||
<div class="shrink-0">
|
||||
<!-- Image preview strip -->
|
||||
{#if attachedImages.length > 0}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each attachedImages as imageUrl, idx (idx)}
|
||||
<div class="relative group">
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Attached image {idx + 1}"
|
||||
class="w-20 h-20 object-cover rounded border border-gray-200 dark:border-white/10"
|
||||
/>
|
||||
<button
|
||||
class="absolute -top-2 -right-2 bg-red-500 text-white rounded-full w-6 h-6 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onclick={() => removeImage(idx)}
|
||||
title="Remove image"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Error message -->
|
||||
{#if imageError}
|
||||
<div class="mb-2 p-2 bg-red-100 dark:bg-red-900/20 text-red-700 dark:text-red-400 rounded text-sm">
|
||||
{imageError}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="flex gap-2">
|
||||
<!-- Hidden file input -->
|
||||
<input
|
||||
type="file"
|
||||
accept=".jpg,.jpeg,.png,.gif,.webp"
|
||||
multiple
|
||||
class="hidden"
|
||||
bind:this={fileInput}
|
||||
onchange={handleImageSelect}
|
||||
/>
|
||||
|
||||
<ExpandableTextarea
|
||||
bind:value={userInput}
|
||||
placeholder="Type a message..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-col gap-2">
|
||||
{#if isStreaming}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancelStreaming}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn"
|
||||
onclick={() => fileInput?.click()}
|
||||
disabled={isStreaming || !$selectedModelStore}
|
||||
title="Attach image"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5">
|
||||
<path fill-rule="evenodd" d="M1 5.25A2.25 2.25 0 0 1 3.25 3h13.5A2.25 2.25 0 0 1 19 5.25v9.5A2.25 2.25 0 0 1 16.75 17H3.25A2.25 2.25 0 0 1 1 14.75v-9.5Zm1.5 5.81v3.69c0 .414.336.75.75.75h13.5a.75.75 0 0 0 .75-.75v-2.69l-2.22-2.219a.75.75 0 0 0-1.06 0l-1.91 1.909.47.47a.75.75 0 1 1-1.06 1.06L6.53 8.091a.75.75 0 0 0-1.06 0l-2.97 2.97ZM12 7a1 1 0 1 1-2 0 1 1 0 0 1 2 0Z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={sendMessage}
|
||||
disabled={(!userInput.trim() && attachedImages.length === 0) || !$selectedModelStore}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,467 @@
|
||||
<script lang="ts">
|
||||
import { renderMarkdown, escapeHtml, renderStreamingMarkdown, createStreamingCache } from "../../lib/markdown";
|
||||
import type { RenderedBlock } from "../../lib/markdown";
|
||||
import { Copy, Check, Pencil, X, Save, RefreshCw, ChevronDown, ChevronRight, Brain, Code } from "lucide-svelte";
|
||||
import { getTextContent, getImageUrls } from "../../lib/types";
|
||||
import type { ContentPart } from "../../lib/types";
|
||||
|
||||
interface Props {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string | ContentPart[];
|
||||
reasoning_content?: string;
|
||||
reasoningTimeMs?: number;
|
||||
isStreaming?: boolean;
|
||||
isReasoning?: boolean;
|
||||
onEdit?: (newContent: string) => void;
|
||||
onRegenerate?: () => void;
|
||||
}
|
||||
|
||||
let { role, content, reasoning_content = "", reasoningTimeMs = 0, isStreaming = false, isReasoning = false, onEdit, onRegenerate }: Props = $props();
|
||||
|
||||
let textContent = $derived(getTextContent(content));
|
||||
let imageUrls = $derived(getImageUrls(content));
|
||||
let hasImages = $derived(imageUrls.length > 0);
|
||||
let canEdit = $derived(onEdit !== undefined && !hasImages);
|
||||
|
||||
let streamingCache = createStreamingCache();
|
||||
let renderedParts = $derived.by(() => {
|
||||
if (role !== "assistant") {
|
||||
return { blocks: [{ id: -1, html: escapeHtml(textContent).replace(/\n/g, '<br>') }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
if (!isStreaming) {
|
||||
streamingCache = createStreamingCache();
|
||||
return { blocks: [{ id: -1, html: renderMarkdown(textContent) }] as RenderedBlock[], pendingHtml: "" };
|
||||
}
|
||||
return renderStreamingMarkdown(textContent, streamingCache);
|
||||
});
|
||||
let copied = $state(false);
|
||||
let showRaw = $state(false);
|
||||
let isEditing = $state(false);
|
||||
let editContent = $state("");
|
||||
let showReasoning = $state(false);
|
||||
let modalImageUrl = $state<string | null>(null);
|
||||
|
||||
function formatDuration(ms: number): string {
|
||||
if (ms < 1000) {
|
||||
return `${ms.toFixed(0)}ms`;
|
||||
}
|
||||
return `${(ms / 1000).toFixed(1)}s`;
|
||||
}
|
||||
|
||||
async function copyToClipboard() {
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(textContent);
|
||||
} else {
|
||||
// Fallback for non-secure contexts (HTTP)
|
||||
const textarea = document.createElement("textarea");
|
||||
textarea.value = textContent;
|
||||
textarea.style.position = "fixed";
|
||||
textarea.style.left = "-9999px";
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand("copy");
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
copied = true;
|
||||
setTimeout(() => (copied = false), 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
}
|
||||
|
||||
function startEdit() {
|
||||
editContent = textContent;
|
||||
isEditing = true;
|
||||
}
|
||||
|
||||
function cancelEdit() {
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function saveEdit() {
|
||||
if (onEdit && editContent.trim() !== textContent) {
|
||||
onEdit(editContent.trim());
|
||||
}
|
||||
isEditing = false;
|
||||
editContent = "";
|
||||
}
|
||||
|
||||
function openModal(imageUrl: string) {
|
||||
modalImageUrl = imageUrl;
|
||||
document.body.style.overflow = "hidden";
|
||||
}
|
||||
|
||||
function closeModal(event?: MouseEvent) {
|
||||
// Only close if clicking the background, not the image
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
modalImageUrl = null;
|
||||
document.body.style.overflow = "";
|
||||
}
|
||||
|
||||
function handleModalKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeModal();
|
||||
}
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
saveEdit();
|
||||
} else if (event.key === "Escape") {
|
||||
cancelEdit();
|
||||
}
|
||||
}
|
||||
|
||||
const COPY_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>`;
|
||||
const CHECK_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M20 6 9 17l-5-5"/></svg>`;
|
||||
|
||||
function codeBlockCopy(node: HTMLElement) {
|
||||
function attachButtons() {
|
||||
node.querySelectorAll<HTMLPreElement>('pre:not([data-copy-btn])').forEach(pre => {
|
||||
pre.setAttribute('data-copy-btn', 'true');
|
||||
const btn = document.createElement('button');
|
||||
btn.className = 'code-copy-btn';
|
||||
btn.title = 'Copy code';
|
||||
btn.innerHTML = COPY_SVG;
|
||||
btn.addEventListener('click', async () => {
|
||||
const text = pre.querySelector('code')?.textContent ?? pre.textContent ?? '';
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else {
|
||||
const ta = document.createElement('textarea');
|
||||
ta.value = text;
|
||||
ta.style.cssText = 'position:fixed;left:-9999px';
|
||||
document.body.appendChild(ta);
|
||||
ta.select();
|
||||
document.execCommand('copy');
|
||||
document.body.removeChild(ta);
|
||||
}
|
||||
btn.innerHTML = CHECK_SVG;
|
||||
btn.classList.add('copied');
|
||||
setTimeout(() => { btn.innerHTML = COPY_SVG; btn.classList.remove('copied'); }, 2000);
|
||||
} catch (e) {
|
||||
console.error('copy failed', e);
|
||||
}
|
||||
});
|
||||
pre.appendChild(btn);
|
||||
});
|
||||
}
|
||||
attachButtons();
|
||||
const mo = new MutationObserver(attachButtons);
|
||||
mo.observe(node, { childList: true, subtree: true });
|
||||
return { destroy: () => mo.disconnect() };
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex {role === 'user' ? 'justify-end' : 'justify-start'} mb-4">
|
||||
<div
|
||||
class="relative group rounded-lg px-4 py-2 {role === 'user'
|
||||
? 'max-w-[85%] bg-primary text-btn-primary-text'
|
||||
: 'w-full sm:w-4/5 bg-surface border border-gray-200 dark:border-white/10'}"
|
||||
>
|
||||
{#if role === "assistant"}
|
||||
{#if reasoning_content || isReasoning}
|
||||
<div class="mb-3 border border-gray-200 dark:border-white/10 rounded overflow-hidden">
|
||||
<button
|
||||
class="w-full flex items-center gap-2 px-3 py-2 bg-gray-50 dark:bg-white/5 hover:bg-gray-100 dark:hover:bg-white/10 transition-colors text-sm"
|
||||
onclick={() => showReasoning = !showReasoning}
|
||||
>
|
||||
{#if showReasoning}
|
||||
<ChevronDown class="w-4 h-4" />
|
||||
{:else}
|
||||
<ChevronRight class="w-4 h-4" />
|
||||
{/if}
|
||||
<Brain class="w-4 h-4" />
|
||||
<span class="font-medium">Reasoning</span>
|
||||
<span class="text-txtsecondary ml-2">
|
||||
({reasoning_content.length} chars{#if !isReasoning && reasoningTimeMs > 0}, {formatDuration(reasoningTimeMs)}{/if})
|
||||
</span>
|
||||
{#if isReasoning}
|
||||
<span class="ml-auto flex items-center gap-1 text-txtsecondary">
|
||||
<span class="w-1.5 h-1.5 bg-primary rounded-full animate-pulse"></span>
|
||||
reasoning...
|
||||
</span>
|
||||
{/if}
|
||||
</button>
|
||||
{#if showReasoning}
|
||||
<div class="px-3 py-2 bg-gray-50/50 dark:bg-white/[0.02] text-sm text-txtsecondary whitespace-pre-wrap font-mono">
|
||||
{reasoning_content}{#if isReasoning}<span class="inline-block w-1.5 h-4 bg-current animate-pulse ml-0.5"></span>{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if hasImages}
|
||||
<div class="mb-3 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-gray-200 dark:border-white/10 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-h-64 rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if showRaw}
|
||||
<div class="whitespace-pre-wrap font-mono text-sm">{textContent}</div>
|
||||
{:else}
|
||||
<div class="prose prose-sm dark:prose-invert max-w-none" use:codeBlockCopy>
|
||||
{#each renderedParts.blocks as block (block.id)}
|
||||
{@html block.html}
|
||||
{/each}
|
||||
{@html renderedParts.pendingHtml}
|
||||
{#if isStreaming && !isReasoning}
|
||||
<span class="inline-block w-2 h-4 bg-current animate-pulse ml-0.5"></span>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{#if !isStreaming}
|
||||
<div class="flex gap-1 mt-2 pt-1 border-t border-gray-200 dark:border-white/10">
|
||||
{#if onRegenerate}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={onRegenerate}
|
||||
title="Regenerate response"
|
||||
>
|
||||
<RefreshCw class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 text-txtsecondary"
|
||||
onclick={copyToClipboard}
|
||||
title={copied ? "Copied!" : "Copy to clipboard"}
|
||||
>
|
||||
{#if copied}
|
||||
<Check class="w-4 h-4 text-green-500" />
|
||||
{:else}
|
||||
<Copy class="w-4 h-4" />
|
||||
{/if}
|
||||
</button>
|
||||
<button
|
||||
class="p-1 rounded hover:bg-black/10 dark:hover:bg-white/10 {showRaw ? 'text-primary' : 'text-txtsecondary'}"
|
||||
onclick={() => showRaw = !showRaw}
|
||||
title={showRaw ? "Show rendered" : "Show raw"}
|
||||
>
|
||||
<Code class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
{#if isEditing}
|
||||
<div class="flex flex-col gap-2 min-w-[300px]">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface text-txtmain focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
rows="3"
|
||||
bind:value={editContent}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
<div class="flex justify-end gap-2">
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={cancelEdit}
|
||||
title="Cancel"
|
||||
>
|
||||
<X class="w-4 h-4" />
|
||||
</button>
|
||||
<button
|
||||
class="p-1.5 rounded hover:bg-white/20"
|
||||
onclick={saveEdit}
|
||||
title="Save"
|
||||
>
|
||||
<Save class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
{#if hasImages}
|
||||
<div class="mb-2 flex flex-wrap gap-2">
|
||||
{#each imageUrls as imageUrl, idx (idx)}
|
||||
<button
|
||||
onclick={() => openModal(imageUrl)}
|
||||
class="cursor-pointer rounded border border-white/20 hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt="Image {idx + 1}"
|
||||
class="max-w-[200px] rounded"
|
||||
/>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="whitespace-pre-wrap pr-8">{textContent}</div>
|
||||
{#if canEdit}
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-0 group-hover:opacity-100 transition-opacity bg-white/20 hover:bg-white/30 shadow-sm"
|
||||
onclick={startEdit}
|
||||
title="Edit message"
|
||||
>
|
||||
<Pencil class="w-4 h-4" />
|
||||
</button>
|
||||
{/if}
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Full-size image modal -->
|
||||
{#if modalImageUrl}
|
||||
<div
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/80 p-4"
|
||||
onclick={(e) => closeModal(e)}
|
||||
onkeydown={handleModalKeyDown}
|
||||
role="button"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 p-2 rounded-lg bg-white/10 hover:bg-white/20 text-white transition-colors"
|
||||
onclick={() => closeModal()}
|
||||
title="Close"
|
||||
>
|
||||
<X class="w-6 h-6" />
|
||||
</button>
|
||||
<img
|
||||
src={modalImageUrl}
|
||||
alt=""
|
||||
class="max-w-full max-h-full rounded pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.prose :global(pre) {
|
||||
position: relative;
|
||||
background-color: var(--color-surface);
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
border-radius: 0.375rem;
|
||||
padding: 0.75rem;
|
||||
padding-right: 2.5rem;
|
||||
overflow-x: auto;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn) {
|
||||
position: absolute;
|
||||
top: 0.375rem;
|
||||
right: 0.375rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-border);
|
||||
background: var(--color-surface);
|
||||
color: var(--color-txtsecondary);
|
||||
cursor: pointer;
|
||||
transition: background-color 0.15s;
|
||||
line-height: 0;
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn:hover) {
|
||||
background: var(--color-secondary);
|
||||
}
|
||||
|
||||
.prose :global(.code-copy-btn.copied) {
|
||||
color: var(--color-success);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.prose :global(code) {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
.prose :global(pre code) {
|
||||
background: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.prose :global(code:not(pre code)) {
|
||||
background-color: var(--color-surface);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
}
|
||||
|
||||
.prose :global(p) {
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(p:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.prose :global(ul),
|
||||
.prose :global(ol) {
|
||||
margin: 0.5rem 0;
|
||||
padding-left: 1.5rem;
|
||||
}
|
||||
|
||||
.prose :global(li) {
|
||||
margin: 0.25rem 0;
|
||||
}
|
||||
|
||||
.prose :global(h1),
|
||||
.prose :global(h2),
|
||||
.prose :global(h3),
|
||||
.prose :global(h4) {
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.prose :global(h1:first-child),
|
||||
.prose :global(h2:first-child),
|
||||
.prose :global(h3:first-child),
|
||||
.prose :global(h4:first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.prose :global(blockquote) {
|
||||
border-left: 3px solid var(--color-primary);
|
||||
padding-left: 1rem;
|
||||
margin: 0.5rem 0;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.prose :global(a) {
|
||||
color: var(--color-primary);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.prose :global(table) {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.prose :global(th),
|
||||
.prose :global(td) {
|
||||
border: 1px solid var(--color-border, rgba(128, 128, 128, 0.2));
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.prose :global(th) {
|
||||
background-color: var(--color-surface);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* Highlight.js theme overrides for dark mode */
|
||||
:global(.dark) .prose :global(.hljs) {
|
||||
background: transparent;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,121 @@
|
||||
<script lang="ts">
|
||||
import { untrack } from "svelte";
|
||||
import { Maximize2, X } from "lucide-svelte";
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
rows?: number;
|
||||
disabled?: boolean;
|
||||
onkeydown?: (event: KeyboardEvent) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
value = $bindable(),
|
||||
placeholder = "",
|
||||
rows = 3,
|
||||
disabled = false,
|
||||
onkeydown,
|
||||
}: Props = $props();
|
||||
|
||||
let isExpanded = $state(false);
|
||||
let expandedValue = $state("");
|
||||
let expandedTextarea: HTMLTextAreaElement | undefined = $state();
|
||||
|
||||
function openExpanded() {
|
||||
expandedValue = value;
|
||||
isExpanded = true;
|
||||
}
|
||||
|
||||
function closeExpanded() {
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function saveExpanded() {
|
||||
value = expandedValue;
|
||||
isExpanded = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Escape") {
|
||||
closeExpanded();
|
||||
}
|
||||
}
|
||||
|
||||
// Focus the textarea when expanded view opens
|
||||
$effect(() => {
|
||||
if (isExpanded && expandedTextarea) {
|
||||
expandedTextarea.focus();
|
||||
const len = untrack(() => expandedValue.length);
|
||||
expandedTextarea.setSelectionRange(len, len);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex-1 relative group flex items-stretch min-h-0">
|
||||
<textarea
|
||||
class="w-full px-3 py-2 pr-10 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-inset focus:ring-primary resize-none"
|
||||
{placeholder}
|
||||
{rows}
|
||||
bind:value
|
||||
{onkeydown}
|
||||
{disabled}
|
||||
></textarea>
|
||||
<button
|
||||
class="absolute top-2 right-2 p-1.5 rounded-lg opacity-60 md:opacity-0 group-hover:opacity-100 transition-opacity bg-surface/90 hover:bg-surface border border-gray-200 dark:border-white/10 shadow-sm"
|
||||
onclick={openExpanded}
|
||||
title="Expand to edit"
|
||||
type="button"
|
||||
{disabled}
|
||||
>
|
||||
<Maximize2 class="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="fixed inset-0 z-50 flex items-center justify-center bg-black/50 p-4">
|
||||
<div class="w-full max-w-4xl h-[80vh] flex flex-col bg-surface rounded-lg shadow-xl border border-gray-200 dark:border-white/10">
|
||||
<!-- Header -->
|
||||
<div class="flex justify-between items-center p-4 border-b border-gray-200 dark:border-white/10">
|
||||
<h3 class="font-medium">Edit Text</h3>
|
||||
<button
|
||||
class="p-1.5 rounded-lg hover:bg-gray-100 dark:hover:bg-white/10"
|
||||
onclick={closeExpanded}
|
||||
title="Close"
|
||||
type="button"
|
||||
>
|
||||
<X class="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Textarea -->
|
||||
<div class="flex-1 p-4">
|
||||
<textarea
|
||||
bind:this={expandedTextarea}
|
||||
class="w-full h-full px-4 py-3 rounded border border-gray-200 dark:border-white/10 bg-card focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
placeholder={placeholder}
|
||||
bind:value={expandedValue}
|
||||
onkeydown={handleKeyDown}
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<!-- Footer -->
|
||||
<div class="flex justify-end gap-2 p-4 border-t border-gray-200 dark:border-white/10">
|
||||
<button
|
||||
class="btn"
|
||||
onclick={closeExpanded}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={saveExpanded}
|
||||
type="button"
|
||||
>
|
||||
Done
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,521 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { generateImage } from "../../lib/imageApi";
|
||||
import { generateSdImage, fetchSdLoras } from "../../lib/sdApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
import ExpandableTextarea from "./ExpandableTextarea.svelte";
|
||||
import type { ImageApiMode, SdApiLora, SdApiLoraRef } from "../../lib/types";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-image-model", "");
|
||||
const selectedSizeStore = persistentStore<string>("playground-image-size", "1024x1024");
|
||||
const apiModeStore = persistentStore<ImageApiMode>("playground-image-api-mode", "openai");
|
||||
|
||||
// SDAPI persistent settings
|
||||
const sdNegativePromptStore = persistentStore<string>("playground-sdapi-negative-prompt", "");
|
||||
const sdStepsStore = persistentStore<number>("playground-sdapi-steps", 20);
|
||||
const sdCfgScaleStore = persistentStore<number>("playground-sdapi-cfg-scale", 7);
|
||||
const sdSeedStore = persistentStore<number>("playground-sdapi-seed", -1);
|
||||
const sdSamplerStore = persistentStore<string>("playground-sdapi-sampler", "");
|
||||
const sdSchedulerStore = persistentStore<string>("playground-sdapi-scheduler", "");
|
||||
const sdBatchSizeStore = persistentStore<number>("playground-sdapi-batch-size", 1);
|
||||
|
||||
let prompt = $state("");
|
||||
let isGenerating = $state(false);
|
||||
let generatedImages = $state<string[]>([]);
|
||||
let error = $state<string | null>(null);
|
||||
let abortController = $state<AbortController | null>(null);
|
||||
let showFullscreen = $state(false);
|
||||
let fullscreenIndex = $state(0);
|
||||
let showSettings = $state(false);
|
||||
|
||||
// SDAPI lora state
|
||||
let availableLoras = $state<SdApiLora[]>([]);
|
||||
let selectedLoras = $state<SdApiLoraRef[]>([]);
|
||||
let isLoadingLoras = $state(false);
|
||||
let lorasLoaded = $state(false);
|
||||
let loraError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
let isSdapi = $derived($apiModeStore === "sdapi");
|
||||
|
||||
$effect(() => {
|
||||
playgroundStores.imageGenerating.set(isGenerating);
|
||||
});
|
||||
|
||||
async function loadLoras() {
|
||||
if (!$selectedModelStore || isLoadingLoras) return;
|
||||
isLoadingLoras = true;
|
||||
loraError = null;
|
||||
try {
|
||||
const loras = await fetchSdLoras($selectedModelStore);
|
||||
availableLoras = loras;
|
||||
lorasLoaded = true;
|
||||
} catch (err) {
|
||||
availableLoras = [];
|
||||
loraError = err instanceof Error ? err.message : "Failed to load LoRAs";
|
||||
lorasLoaded = false;
|
||||
} finally {
|
||||
isLoadingLoras = false;
|
||||
}
|
||||
}
|
||||
|
||||
function addLora(event: Event) {
|
||||
const select = event.target as HTMLSelectElement;
|
||||
const path = select.value;
|
||||
if (!path) return;
|
||||
|
||||
const lora = availableLoras.find((l) => l.path === path);
|
||||
if (lora && !selectedLoras.some((l) => l.path === path)) {
|
||||
selectedLoras = [...selectedLoras, { path: lora.path, multiplier: 1.0 }];
|
||||
}
|
||||
select.value = "";
|
||||
}
|
||||
|
||||
function removeLora(path: string) {
|
||||
selectedLoras = selectedLoras.filter((l) => l.path !== path);
|
||||
}
|
||||
|
||||
function updateLoraMultiplier(path: string, multiplier: number) {
|
||||
selectedLoras = selectedLoras.map((l) =>
|
||||
l.path === path ? { ...l, multiplier } : l
|
||||
);
|
||||
}
|
||||
|
||||
function getLoraName(path: string): string {
|
||||
return availableLoras.find((l) => l.path === path)?.name ?? path;
|
||||
}
|
||||
|
||||
async function generate() {
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (!trimmedPrompt || !$selectedModelStore || isGenerating) return;
|
||||
|
||||
isGenerating = true;
|
||||
error = null;
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
if (isSdapi) {
|
||||
const [w, h] = $selectedSizeStore.split("x").map(Number);
|
||||
const request = {
|
||||
model: $selectedModelStore,
|
||||
prompt: trimmedPrompt,
|
||||
negative_prompt: $sdNegativePromptStore || undefined,
|
||||
width: w,
|
||||
height: h,
|
||||
steps: $sdStepsStore,
|
||||
cfg_scale: $sdCfgScaleStore,
|
||||
seed: $sdSeedStore,
|
||||
batch_size: $sdBatchSizeStore,
|
||||
sampler_name: $sdSamplerStore || undefined,
|
||||
scheduler: $sdSchedulerStore || undefined,
|
||||
lora: selectedLoras.length > 0 ? selectedLoras : undefined,
|
||||
};
|
||||
|
||||
const response = await generateSdImage(request, abortController.signal);
|
||||
if (response.images && response.images.length > 0) {
|
||||
generatedImages = response.images.map(
|
||||
(img) => `data:image/png;base64,${img}`
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const response = await generateImage(
|
||||
$selectedModelStore,
|
||||
trimmedPrompt,
|
||||
$selectedSizeStore,
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
if (response.data && response.data.length > 0) {
|
||||
const imageData = response.data[0];
|
||||
if (imageData.b64_json) {
|
||||
generatedImages = [`data:image/png;base64,${imageData.b64_json}`];
|
||||
} else if (imageData.url) {
|
||||
generatedImages = [imageData.url];
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancelGeneration() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clearImage() {
|
||||
generatedImages = [];
|
||||
error = null;
|
||||
prompt = "";
|
||||
}
|
||||
|
||||
function downloadImage(index: number = 0) {
|
||||
const img = generatedImages[index];
|
||||
if (!img) return;
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = img;
|
||||
link.download = `generated-image-${Date.now()}-${index}.png`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
|
||||
function openFullscreen(index: number = 0) {
|
||||
fullscreenIndex = index;
|
||||
showFullscreen = true;
|
||||
}
|
||||
|
||||
function closeFullscreen(event?: MouseEvent) {
|
||||
if (event && event.target !== event.currentTarget) {
|
||||
return;
|
||||
}
|
||||
showFullscreen = false;
|
||||
}
|
||||
|
||||
function handleKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
generate();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Model selector and mode toggle -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} />
|
||||
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$apiModeStore}
|
||||
disabled={isGenerating}
|
||||
>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="sdapi">SDAPI</option>
|
||||
</select>
|
||||
|
||||
<select
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$selectedSizeStore}
|
||||
disabled={isGenerating}
|
||||
>
|
||||
<optgroup label="Square">
|
||||
<option value="512x512">512x512</option>
|
||||
<option value="1024x1024">1024x1024</option>
|
||||
</optgroup>
|
||||
<optgroup label="Landscape">
|
||||
<option value="1024x768">1024x768 (4:3)</option>
|
||||
<option value="1280x720">1280x720 (16:9)</option>
|
||||
<option value="1792x1024">1792x1024 (SDXL)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Portrait">
|
||||
<option value="768x1024">768x1024 (3:4)</option>
|
||||
<option value="720x1280">720x1280 (9:16)</option>
|
||||
<option value="1024x1792">1024x1792 (SDXL)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
|
||||
{#if isSdapi}
|
||||
<button
|
||||
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors"
|
||||
onclick={() => showSettings = !showSettings}
|
||||
>
|
||||
{showSettings ? "Hide Settings" : "Settings"}
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- SDAPI Settings Panel -->
|
||||
{#if isSdapi && showSettings}
|
||||
<div class="shrink-0 mb-4 p-4 rounded border border-gray-200 dark:border-white/10 bg-surface">
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-3 mb-3">
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Steps</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdStepsStore}
|
||||
min="1"
|
||||
max="150"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">CFG Scale</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdCfgScaleStore}
|
||||
min="1"
|
||||
max="30"
|
||||
step="0.5"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Seed (-1 = random)</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSeedStore}
|
||||
min="-1"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Batch Size</span>
|
||||
<input
|
||||
type="number"
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdBatchSizeStore}
|
||||
min="1"
|
||||
max="8"
|
||||
/>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Sampler</span>
|
||||
<select
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSamplerStore}
|
||||
>
|
||||
<option value="">Default</option>
|
||||
<option value="euler_a">euler_a</option>
|
||||
<option value="euler">euler</option>
|
||||
<option value="heun">heun</option>
|
||||
<option value="dpm2">dpm2</option>
|
||||
<option value="dpmpp2s_a">dpmpp2s_a</option>
|
||||
<option value="dpmpp2m">dpmpp2m</option>
|
||||
<option value="dpmpp2mv2">dpmpp2mv2</option>
|
||||
<option value="ipndm">ipndm</option>
|
||||
<option value="ipndm_v">ipndm_v</option>
|
||||
<option value="lcm">lcm</option>
|
||||
<option value="ddim_trailing">ddim_trailing</option>
|
||||
<option value="tcd">tcd</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="flex flex-col gap-1">
|
||||
<span class="text-xs text-txtsecondary">Scheduler</span>
|
||||
<select
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value={$sdSchedulerStore}
|
||||
>
|
||||
<option value="">Auto for model</option>
|
||||
<option value="discrete">discrete</option>
|
||||
<option value="karras">karras</option>
|
||||
<option value="exponential">exponential</option>
|
||||
<option value="ays">ays</option>
|
||||
<option value="gits">gits</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label class="flex flex-col gap-1 mb-3">
|
||||
<span class="text-xs text-txtsecondary">Negative Prompt</span>
|
||||
<textarea
|
||||
class="px-2 py-1 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-y text-sm"
|
||||
bind:value={$sdNegativePromptStore}
|
||||
rows="2"
|
||||
placeholder="Elements to avoid..."
|
||||
></textarea>
|
||||
</label>
|
||||
|
||||
<!-- LoRA Selection -->
|
||||
<div>
|
||||
<span class="text-xs text-txtsecondary block mb-1">LoRAs</span>
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<button
|
||||
class="px-3 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface hover:bg-secondary-hover transition-colors disabled:opacity-50"
|
||||
onclick={loadLoras}
|
||||
disabled={!$selectedModelStore || isLoadingLoras}
|
||||
>
|
||||
{isLoadingLoras ? "Loading..." : lorasLoaded ? "Reload LoRAs" : "Load LoRAs"}
|
||||
</button>
|
||||
{#if lorasLoaded && availableLoras.length > 0}
|
||||
<select
|
||||
class="flex-1 px-2 py-1.5 text-sm rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
onchange={addLora}
|
||||
>
|
||||
<option value="">Add a LoRA...</option>
|
||||
{#each availableLoras.filter((l) => !selectedLoras.some((s) => s.path === l.path)) as lora}
|
||||
<option value={lora.path}>{lora.name}</option>
|
||||
{/each}
|
||||
</select>
|
||||
{/if}
|
||||
</div>
|
||||
{#if loraError}
|
||||
<p class="text-xs text-red-500 mb-1">{loraError}</p>
|
||||
{/if}
|
||||
{#if lorasLoaded && availableLoras.length === 0}
|
||||
<p class="text-xs text-txtsecondary">No LoRAs available</p>
|
||||
{/if}
|
||||
{#if selectedLoras.length > 0}
|
||||
<div class="flex flex-col gap-1.5">
|
||||
{#each selectedLoras as lora}
|
||||
<div class="flex items-center gap-2 text-sm">
|
||||
<span class="flex-1 truncate">{getLoraName(lora.path)}</span>
|
||||
<input
|
||||
type="number"
|
||||
class="w-20 px-1.5 py-1 text-xs rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-1 focus:ring-primary"
|
||||
value={lora.multiplier}
|
||||
oninput={(e) => updateLoraMultiplier(lora.path, parseFloat((e.target as HTMLInputElement).value) || 1)}
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.1"
|
||||
/>
|
||||
<button
|
||||
class="px-1.5 py-0.5 text-xs rounded border border-gray-200 dark:border-white/10 hover:bg-red-500 hover:text-white hover:border-red-500 transition-colors"
|
||||
onclick={() => removeLora(lora.path)}
|
||||
aria-label="Remove LoRA"
|
||||
>
|
||||
x
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Empty state for no models configured -->
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to generate images.</p>
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Image display area -->
|
||||
<div class="flex-1 overflow-auto mb-4 flex items-center justify-center bg-surface border border-gray-200 dark:border-white/10 rounded">
|
||||
{#if isGenerating}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<div class="inline-block w-8 h-8 border-4 border-primary border-t-transparent rounded-full animate-spin mb-2"></div>
|
||||
<p>Generating image...</p>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="text-center text-red-500 p-4">
|
||||
<p class="font-medium">Error</p>
|
||||
<p class="text-sm mt-1">{error}</p>
|
||||
</div>
|
||||
{:else if generatedImages.length > 1}
|
||||
<!-- Grid for multiple images (batch) -->
|
||||
<div class="grid grid-cols-2 gap-2 p-2 w-full h-full overflow-auto">
|
||||
{#each generatedImages as img, i}
|
||||
<div class="relative flex items-center justify-center">
|
||||
<button
|
||||
class="p-0 border-0 bg-transparent cursor-pointer"
|
||||
onclick={() => openFullscreen(i)}
|
||||
aria-label="View fullscreen"
|
||||
>
|
||||
<img
|
||||
src={img}
|
||||
alt="AI generated content {i + 1}"
|
||||
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
class="absolute bottom-2 right-2 p-1.5 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
|
||||
onclick={(e) => { e.stopPropagation(); downloadImage(i); }}
|
||||
aria-label="Download image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else if generatedImages.length === 1}
|
||||
<div class="relative max-w-full max-h-full flex items-center justify-center">
|
||||
<button
|
||||
class="p-0 border-0 bg-transparent cursor-pointer"
|
||||
onclick={() => openFullscreen(0)}
|
||||
aria-label="View fullscreen"
|
||||
>
|
||||
<img
|
||||
src={generatedImages[0]}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain hover:opacity-90 transition-opacity"
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
class="absolute bottom-2 right-2 p-2 bg-black/60 hover:bg-black/80 text-white rounded-full transition-colors"
|
||||
onclick={(e) => { e.stopPropagation(); downloadImage(0); }}
|
||||
aria-label="Download image"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="text-center text-txtsecondary">
|
||||
<p>Enter a prompt below to generate an image</p>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Prompt input area -->
|
||||
<div class="shrink-0 flex flex-col md:flex-row gap-2">
|
||||
<ExpandableTextarea
|
||||
bind:value={prompt}
|
||||
placeholder="Describe the image you want to generate..."
|
||||
rows={3}
|
||||
onkeydown={handleKeyDown}
|
||||
disabled={isGenerating || !$selectedModelStore}
|
||||
/>
|
||||
<div class="flex flex-row md:flex-col gap-2">
|
||||
{#if isGenerating}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white flex-1 md:flex-none" onclick={cancelGeneration}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90 flex-1 md:flex-none"
|
||||
onclick={generate}
|
||||
disabled={!prompt.trim() || !$selectedModelStore}
|
||||
>
|
||||
Generate
|
||||
</button>
|
||||
<button
|
||||
class="btn flex-1 md:flex-none"
|
||||
onclick={clearImage}
|
||||
disabled={generatedImages.length === 0 && !error && !prompt.trim()}
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Fullscreen dialog -->
|
||||
{#if showFullscreen && generatedImages[fullscreenIndex]}
|
||||
<div
|
||||
class="fixed inset-0 bg-black/90 z-50 flex items-center justify-center p-4"
|
||||
onclick={(e) => closeFullscreen(e)}
|
||||
onkeydown={(e) => e.key === 'Escape' && closeFullscreen()}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
tabindex="-1"
|
||||
>
|
||||
<button
|
||||
class="absolute top-4 right-4 text-white hover:text-gray-300 text-2xl w-10 h-10 flex items-center justify-center rounded-full hover:bg-white/10 transition-colors"
|
||||
onclick={() => closeFullscreen()}
|
||||
aria-label="Close fullscreen"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
<img
|
||||
src={generatedImages[fullscreenIndex]}
|
||||
alt="AI generated content"
|
||||
class="max-w-full max-h-full object-contain pointer-events-none"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -0,0 +1,44 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { groupModels } from "../../lib/modelUtils";
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props();
|
||||
|
||||
let grouped = $derived(groupModels($models));
|
||||
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
|
||||
</script>
|
||||
|
||||
{#if hasModels}
|
||||
<select
|
||||
class="min-w-0 flex-1 basis-48 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
bind:value
|
||||
{disabled}
|
||||
>
|
||||
<option value="">{placeholder}</option>
|
||||
{#if grouped.local.length > 0}
|
||||
<optgroup label="Local">
|
||||
{#each grouped.local as model (model.id)}
|
||||
<option value={model.id}>{model.id}</option>
|
||||
{#if model.aliases}
|
||||
{#each model.aliases as alias (alias)}
|
||||
<option value={alias}> ↳ {alias}</option>
|
||||
{/each}
|
||||
{/if}
|
||||
{/each}
|
||||
</optgroup>
|
||||
{/if}
|
||||
{#each Object.entries(grouped.peersByProvider).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||
<optgroup label="Peer: {peerId}">
|
||||
{#each peerModels as model (model.id)}
|
||||
<option value={model.id}>{model.id}</option>
|
||||
{/each}
|
||||
</optgroup>
|
||||
{/each}
|
||||
</select>
|
||||
{/if}
|
||||
@@ -0,0 +1,14 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
featureName: string;
|
||||
}
|
||||
|
||||
let { featureName }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex items-center justify-center h-full">
|
||||
<div class="text-center text-txtsecondary">
|
||||
<p class="text-lg">{featureName}</p>
|
||||
<p class="text-sm mt-2">To be implemented</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,406 @@
|
||||
<script lang="ts">
|
||||
import { models } from "../../stores/api";
|
||||
import { persistentStore } from "../../stores/persistent";
|
||||
import { rerank } from "../../lib/rerankApi";
|
||||
import { playgroundStores } from "../../stores/playgroundActivity";
|
||||
import ModelSelector from "./ModelSelector.svelte";
|
||||
|
||||
type RerankRow = { doc: string; score: number | null };
|
||||
type SortOrder = "none" | "asc" | "desc";
|
||||
type EditorMode = "table" | "json";
|
||||
|
||||
const selectedModelStore = persistentStore<string>("playground-rerank-model", "");
|
||||
|
||||
const defaultQuery = "How do LLM's work?";
|
||||
const defaultDocs = [
|
||||
"Large language models (LLMs) use transformer architectures to predict the next token in a sequence based on massive amounts of text data.",
|
||||
"LLMs are trained on diverse internet text, learning statistical patterns of language that allow them to generate coherent responses.",
|
||||
"During training, LLMs minimize a loss function that measures the difference between predicted and actual tokens across billions of examples.",
|
||||
"Attention mechanisms in transformers enable LLMs to weigh the importance of different words when generating output.",
|
||||
"Fine\u2011tuning allows a pre\u2011trained LLM to adapt to a specific downstream task with a smaller dataset.",
|
||||
"Neural networks consist of layers of interconnected neurons that adjust their weights during back\u2011propagation.",
|
||||
"The history of the Roman Empire spanned over a thousand years.",
|
||||
"Soccer is the most popular sport in many countries around the world.",
|
||||
"Quantum computing uses qubits to perform calculations that are intractable for classical computers.",
|
||||
];
|
||||
|
||||
let query = $state(defaultQuery);
|
||||
let rows = $state<RerankRow[]>([
|
||||
...defaultDocs.map((doc) => ({ doc, score: null })),
|
||||
{ doc: "", score: null },
|
||||
]);
|
||||
let isLoading = $state(false);
|
||||
let error = $state<string | null>(null);
|
||||
let usage = $state<{ prompt_tokens: number; total_tokens: number } | null>(null);
|
||||
let abortController: AbortController | null = null;
|
||||
let sortOrder = $state<SortOrder>("desc");
|
||||
let editorMode = $state<EditorMode>("table");
|
||||
let jsonText = $state("");
|
||||
let jsonError = $state<string | null>(null);
|
||||
|
||||
let hasModels = $derived($models.some((m) => !m.unlisted));
|
||||
|
||||
let canSubmit = $derived((() => {
|
||||
if (!$selectedModelStore || isLoading) return false;
|
||||
if (editorMode === "json") {
|
||||
try {
|
||||
const parsed = JSON.parse(jsonText) as Record<string, unknown>;
|
||||
return (
|
||||
typeof parsed.query === "string" &&
|
||||
parsed.query.trim() !== "" &&
|
||||
Array.isArray(parsed.documents) &&
|
||||
(parsed.documents as unknown[]).some(
|
||||
(d) => typeof d === "string" && (d as string).trim() !== ""
|
||||
)
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return query.trim() !== "" && rows.some((r) => r.doc.trim() !== "");
|
||||
})());
|
||||
|
||||
// Display rows with sort applied (display-only transform, rows[] is never mutated by sorting)
|
||||
let displayRows = $derived((() => {
|
||||
const indexed = rows.map((row, i) => ({ row, i }));
|
||||
if (sortOrder === "none") return indexed;
|
||||
return [...indexed].sort((a, b) => {
|
||||
if (a.row.score === null && b.row.score === null) return 0;
|
||||
if (a.row.score === null) return 1;
|
||||
if (b.row.score === null) return -1;
|
||||
return sortOrder === "desc"
|
||||
? b.row.score - a.row.score
|
||||
: a.row.score - b.row.score;
|
||||
});
|
||||
})());
|
||||
|
||||
// Auto-add a new empty row when the last row gets content (table mode only)
|
||||
$effect(() => {
|
||||
if (editorMode === "table" && rows[rows.length - 1]?.doc.trim() !== "") {
|
||||
rows = [...rows, { doc: "", score: null }];
|
||||
}
|
||||
});
|
||||
|
||||
// Sync loading state to activity store
|
||||
$effect(() => {
|
||||
playgroundStores.rerankLoading.set(isLoading);
|
||||
});
|
||||
|
||||
function switchToJson() {
|
||||
if (editorMode === "json") return;
|
||||
const docs = rows.filter((r) => r.doc.trim() !== "").map((r) => r.doc);
|
||||
jsonText = JSON.stringify({ query, documents: docs }, null, 2);
|
||||
jsonError = null;
|
||||
editorMode = "json";
|
||||
}
|
||||
|
||||
function switchToTable() {
|
||||
if (editorMode === "table") return;
|
||||
if (jsonText.trim() === "") {
|
||||
query = "";
|
||||
rows = [{ doc: "", score: null }];
|
||||
jsonError = null;
|
||||
editorMode = "table";
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(jsonText) as unknown;
|
||||
if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
|
||||
throw new Error("Expected a JSON object");
|
||||
}
|
||||
const obj = parsed as Record<string, unknown>;
|
||||
if (typeof obj.query !== "string") throw new Error('"query" must be a string');
|
||||
if (!Array.isArray(obj.documents)) throw new Error('"documents" must be an array');
|
||||
query = obj.query;
|
||||
const newRows: RerankRow[] = (obj.documents as unknown[]).map((d) => ({
|
||||
doc: typeof d === "string" ? d : String(d),
|
||||
score: null,
|
||||
}));
|
||||
if (newRows.length === 0 || newRows[newRows.length - 1].doc.trim() !== "") {
|
||||
newRows.push({ doc: "", score: null });
|
||||
}
|
||||
rows = newRows;
|
||||
jsonError = null;
|
||||
editorMode = "table";
|
||||
} catch (err) {
|
||||
jsonError = err instanceof Error ? err.message : "Invalid JSON";
|
||||
}
|
||||
}
|
||||
|
||||
function cycleSortOrder() {
|
||||
sortOrder = sortOrder === "none" ? "desc" : sortOrder === "desc" ? "asc" : "none";
|
||||
}
|
||||
|
||||
function sortIndicator(): string {
|
||||
if (sortOrder === "desc") return " ↓";
|
||||
if (sortOrder === "asc") return " ↑";
|
||||
return "";
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
if (!canSubmit) return;
|
||||
|
||||
let submitQuery: string;
|
||||
let nonEmptyEntries: { originalIndex: number; doc: string }[];
|
||||
|
||||
if (editorMode === "json") {
|
||||
// Parse JSON, sync state to table, then submit
|
||||
try {
|
||||
const parsed = JSON.parse(jsonText) as Record<string, unknown>;
|
||||
submitQuery = parsed.query as string;
|
||||
const docs = (parsed.documents as string[]).filter((d) => d.trim() !== "");
|
||||
const newRows: RerankRow[] = docs.map((d) => ({ doc: d, score: null }));
|
||||
newRows.push({ doc: "", score: null });
|
||||
rows = newRows;
|
||||
query = submitQuery;
|
||||
editorMode = "table";
|
||||
} catch {
|
||||
error = "Invalid JSON — fix before submitting";
|
||||
return;
|
||||
}
|
||||
nonEmptyEntries = rows
|
||||
.map((r, i) => ({ originalIndex: i, doc: r.doc }))
|
||||
.filter((e) => e.doc.trim() !== "");
|
||||
} else {
|
||||
submitQuery = query;
|
||||
nonEmptyEntries = rows
|
||||
.map((r, i) => ({ originalIndex: i, doc: r.doc }))
|
||||
.filter((e) => e.doc.trim() !== "");
|
||||
}
|
||||
|
||||
isLoading = true;
|
||||
error = null;
|
||||
usage = null;
|
||||
|
||||
// Clear previous scores
|
||||
rows = rows.map((r) => ({ ...r, score: null }));
|
||||
|
||||
abortController = new AbortController();
|
||||
|
||||
try {
|
||||
const response = await rerank(
|
||||
$selectedModelStore,
|
||||
submitQuery,
|
||||
nonEmptyEntries.map((e) => e.doc),
|
||||
abortController.signal
|
||||
);
|
||||
|
||||
usage = response.usage;
|
||||
|
||||
// Map result.index (position in submitted docs array) back to original rows[] index
|
||||
const updated = rows.map((r) => ({ ...r }));
|
||||
for (const result of response.results) {
|
||||
const entry = nonEmptyEntries[result.index];
|
||||
if (entry !== undefined) {
|
||||
updated[entry.originalIndex].score = result.relevance_score;
|
||||
}
|
||||
}
|
||||
rows = updated;
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// User cancelled
|
||||
} else {
|
||||
error = err instanceof Error ? err.message : "An error occurred";
|
||||
}
|
||||
} finally {
|
||||
isLoading = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
function cancel() {
|
||||
abortController?.abort();
|
||||
}
|
||||
|
||||
function clear() {
|
||||
query = defaultQuery;
|
||||
rows = [...defaultDocs.map((doc) => ({ doc, score: null })), { doc: "", score: null }];
|
||||
error = null;
|
||||
usage = null;
|
||||
sortOrder = "desc";
|
||||
jsonText = "";
|
||||
jsonError = null;
|
||||
}
|
||||
|
||||
function deleteRow(originalIndex: number) {
|
||||
if (rows.length <= 1) return;
|
||||
rows = rows.filter((_, i) => i !== originalIndex);
|
||||
}
|
||||
|
||||
function updateDoc(originalIndex: number, value: string) {
|
||||
const updated = rows.map((r) => ({ ...r }));
|
||||
updated[originalIndex].doc = value;
|
||||
rows = updated;
|
||||
}
|
||||
|
||||
function scoreColor(score: number | null): string {
|
||||
if (score === null) return "text-txtsecondary";
|
||||
if (score > 0) return "text-green-600 dark:text-green-400";
|
||||
return "text-red-500 dark:text-red-400";
|
||||
}
|
||||
|
||||
function formatScore(score: number | null): string {
|
||||
if (score === null) return "—";
|
||||
return score.toFixed(3);
|
||||
}
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent) {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
submit();
|
||||
}
|
||||
}
|
||||
|
||||
let isCleared = $derived(
|
||||
query === defaultQuery &&
|
||||
rows.every((r, i) => r.score === null && r.doc === (defaultDocs[i] ?? "")) &&
|
||||
rows.length === defaultDocs.length + 1 &&
|
||||
!jsonText.trim() &&
|
||||
!error &&
|
||||
!usage
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col h-full">
|
||||
<!-- Top bar: model selector + query input (table mode) + mode toggle -->
|
||||
<div class="shrink-0 flex flex-wrap gap-2 mb-4">
|
||||
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} />
|
||||
{#if editorMode === "table"}
|
||||
<input
|
||||
type="text"
|
||||
class="min-w-0 flex-1 basis-48 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
|
||||
placeholder="Query..."
|
||||
bind:value={query}
|
||||
disabled={isLoading}
|
||||
onkeydown={handleKeyDown}
|
||||
/>
|
||||
{/if}
|
||||
<!-- Table / JSON toggle -->
|
||||
<div class="flex rounded border border-gray-200 dark:border-white/10 overflow-hidden shrink-0">
|
||||
<button
|
||||
class="px-3 py-1.5 text-sm transition-colors {editorMode === 'table'
|
||||
? 'bg-primary text-btn-primary-text'
|
||||
: 'bg-surface hover:bg-secondary-hover'}"
|
||||
onclick={switchToTable}
|
||||
disabled={isLoading}
|
||||
>
|
||||
Table
|
||||
</button>
|
||||
<button
|
||||
class="px-3 py-1.5 text-sm border-l border-gray-200 dark:border-white/10 transition-colors {editorMode === 'json'
|
||||
? 'bg-primary text-btn-primary-text'
|
||||
: 'bg-surface hover:bg-secondary-hover'}"
|
||||
onclick={switchToJson}
|
||||
disabled={isLoading}
|
||||
>
|
||||
JSON
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if !hasModels}
|
||||
<div class="flex-1 flex items-center justify-center text-txtsecondary">
|
||||
<p>No models configured. Add models to your configuration to use reranking.</p>
|
||||
</div>
|
||||
{:else if editorMode === "json"}
|
||||
<!-- JSON editor -->
|
||||
<div class="flex-1 flex flex-col min-h-0 mb-4">
|
||||
<textarea
|
||||
class="flex-1 w-full font-mono text-sm px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary resize-none"
|
||||
bind:value={jsonText}
|
||||
disabled={isLoading}
|
||||
placeholder={'{\n "query": "your search query",\n "documents": [\n "document one",\n "document two"\n ]\n}'}
|
||||
spellcheck={false}
|
||||
></textarea>
|
||||
{#if jsonError}
|
||||
<p class="mt-1 text-sm text-red-500">{jsonError}</p>
|
||||
{/if}
|
||||
</div>
|
||||
{:else}
|
||||
<!-- Document table -->
|
||||
<div class="flex-1 overflow-y-auto mb-4 border border-gray-200 dark:border-white/10 rounded">
|
||||
<table class="w-full border-collapse table-fixed">
|
||||
<colgroup>
|
||||
<col class="w-auto" />
|
||||
<col style="width: 120px" />
|
||||
<col style="width: 40px" />
|
||||
</colgroup>
|
||||
<thead class="sticky top-0 bg-surface border-b border-gray-200 dark:border-white/10">
|
||||
<tr>
|
||||
<th class="px-3 py-2 text-left text-sm font-medium text-txtsecondary">Document</th>
|
||||
<th
|
||||
class="px-3 py-2 text-right text-sm font-medium text-txtsecondary cursor-pointer select-none hover:text-txtprimary transition-colors"
|
||||
onclick={cycleSortOrder}
|
||||
>
|
||||
Score{sortIndicator()}
|
||||
</th>
|
||||
<th class="px-2 py-2"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each displayRows as { row, i } (i)}
|
||||
<tr class="border-b border-gray-100 dark:border-white/5 last:border-0">
|
||||
<td class="px-3 py-1.5">
|
||||
<input
|
||||
type="text"
|
||||
class="w-full bg-transparent focus:outline-none focus:ring-1 focus:ring-primary rounded px-1 py-0.5"
|
||||
placeholder={i === rows.length - 1 ? "Add document..." : "Document text..."}
|
||||
value={row.doc}
|
||||
oninput={(e) => updateDoc(i, (e.target as HTMLInputElement).value)}
|
||||
disabled={isLoading}
|
||||
onkeydown={handleKeyDown}
|
||||
/>
|
||||
</td>
|
||||
<td class="px-3 py-1.5 text-right font-mono text-sm {scoreColor(row.score)}">
|
||||
{#if isLoading && row.score === null && row.doc.trim() !== ""}
|
||||
<span class="inline-block w-4 h-4 border-2 border-current border-t-transparent rounded-full animate-spin align-middle"></span>
|
||||
{:else}
|
||||
{formatScore(row.score)}
|
||||
{/if}
|
||||
</td>
|
||||
<td class="px-2 py-1.5 text-center">
|
||||
<button
|
||||
class="w-7 h-7 flex items-center justify-center text-txtsecondary hover:text-red-500 transition-colors rounded disabled:opacity-30 disabled:cursor-not-allowed"
|
||||
onclick={() => deleteRow(i)}
|
||||
disabled={rows.length <= 1}
|
||||
tabindex="-1"
|
||||
aria-label="Remove row"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Bottom toolbar -->
|
||||
{#if hasModels}
|
||||
<div class="shrink-0 flex flex-wrap items-center gap-2">
|
||||
{#if isLoading}
|
||||
<button class="btn bg-red-500 hover:bg-red-600 text-white" onclick={cancel}>
|
||||
Cancel
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="btn bg-primary text-btn-primary-text hover:opacity-90"
|
||||
onclick={submit}
|
||||
disabled={!canSubmit}
|
||||
>
|
||||
Rerank
|
||||
</button>
|
||||
<button class="btn" onclick={clear} disabled={isCleared}>
|
||||
Clear
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
{#if error}
|
||||
<span class="text-sm text-red-500 ml-2">{error}</span>
|
||||
{:else if usage}
|
||||
<span class="text-sm text-txtsecondary ml-2">{usage.total_tokens} tokens</span>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||