Compare commits
104 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e8d4384cd2 | |||
| ce28485be2 | |||
| 3cd7837b1f | |||
| 0b31ccacc1 | |||
| 5938dbee8f | |||
| 66639e83f7 | |||
| 625b296720 | |||
| 231e62291c | |||
| 57ac666598 | |||
| 69728301f5 | |||
| c176fa70f1 | |||
| 5e3c646829 | |||
| c3f0d43e6e | |||
| f6cf9f5844 | |||
| 121fd93ad8 | |||
| 17233e9278 | |||
| 4866d16c3e | |||
| 35193f82f1 | |||
| 40e39f7a86 | |||
| a9d840ffd7 | |||
| 7b2b82777f | |||
| d87f0ce2c5 | |||
| 06bc6a614c | |||
| a37b4866d8 | |||
| 981910d734 | |||
| a185efe37e | |||
| 1dd1aadf93 | |||
| 955900972a | |||
| c2c8cfaf81 | |||
| 1e440770ea | |||
| c794273c83 | |||
| 6574a52cbb | |||
| 8fabc75634 | |||
| e5e7391b6d | |||
| 2c282dccad | |||
| 916d13f5bd | |||
| a3725e7d09 | |||
| 15bd55d3a9 | |||
| c3c258a55d | |||
| 29a38fde0d | |||
| d569681daa | |||
| 24efdb76b1 | |||
| cc77139ff8 | |||
| 390a35bf93 | |||
| 181f71ca11 | |||
| 49546e2cf2 | |||
| 2c078964f4 | |||
| 175bb36fb1 | |||
| aedb640471 | |||
| 2f377f6dc6 | |||
| 64e4c79fc3 | |||
| 19fb5f35e9 | |||
| b45102bde8 | |||
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 |
@@ -4,12 +4,19 @@ early_access: false
|
|||||||
reviews:
|
reviews:
|
||||||
profile: "chill"
|
profile: "chill"
|
||||||
request_changes_workflow: false
|
request_changes_workflow: false
|
||||||
high_level_summary: true
|
high_level_summary: false
|
||||||
poem: false
|
poem: false
|
||||||
review_status: true
|
review_status: true
|
||||||
collapse_walkthrough: false
|
collapse_walkthrough: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
finishing_touches:
|
||||||
|
docstrings:
|
||||||
|
enabled: false
|
||||||
auto_review:
|
auto_review:
|
||||||
enabled: true
|
enabled: true
|
||||||
drafts: false
|
drafts: false
|
||||||
chat:
|
chat:
|
||||||
auto_reply: true
|
auto_reply: true
|
||||||
|
issue_enrichment:
|
||||||
|
planning:
|
||||||
|
enabled: false
|
||||||
|
|||||||
@@ -4,11 +4,15 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- "config-schema.json"
|
- "config-schema.json"
|
||||||
|
- "config.example.yaml"
|
||||||
|
- ".github/workflows/config-schema.yml"
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
paths:
|
paths:
|
||||||
- "config-schema.json"
|
- "config-schema.json"
|
||||||
|
- "config.example.yaml"
|
||||||
|
- ".github/workflows/config-schema.yml"
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
@@ -39,3 +43,14 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "✓ config-schema.json is valid"
|
echo "✓ config-schema.json is valid"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.x"
|
||||||
|
|
||||||
|
- name: Install check-jsonschema
|
||||||
|
run: pip install check-jsonschema
|
||||||
|
|
||||||
|
- name: Validate config.example.yaml against schema
|
||||||
|
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||||
|
|||||||
@@ -10,17 +10,44 @@ on:
|
|||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
# Run on workflow file changes (without pushing)
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/containers.yml'
|
||||||
|
- 'docker/build-container.sh'
|
||||||
|
- 'docker/*.Containerfile'
|
||||||
|
|
||||||
|
# grant permissions on GITHUB_TOKEN to publish packages
|
||||||
|
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
id-token: write
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-and-push:
|
build-and-push:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [intel, cuda, vulkan, cpu, musa]
|
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Free up disk space
|
||||||
|
if: matrix.platform == 'rocm'
|
||||||
|
run: |
|
||||||
|
echo "Before cleanup:"
|
||||||
|
df -h
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /opt/ghc
|
||||||
|
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||||
|
sudo docker system prune -af
|
||||||
|
echo "After cleanup:"
|
||||||
|
df -h
|
||||||
|
|
||||||
- name: Log in to GitHub Container Registry
|
- name: Log in to GitHub Container Registry
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
@@ -31,7 +58,7 @@ jobs:
|
|||||||
- name: Run build-container
|
- name: Run build-container
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||||
|
|
||||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||||
|
|||||||
@@ -3,9 +3,25 @@ name: Windows CI
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "main" ]
|
branches: [ "main" ]
|
||||||
|
# only run when backend source changes
|
||||||
|
# cmd/ is excluded because it contains utilities without tests
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci-windows.yml'
|
||||||
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "main" ]
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci-windows.yml'
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
@@ -28,7 +44,7 @@ jobs:
|
|||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
@@ -43,7 +59,7 @@ jobs:
|
|||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -2,53 +2,68 @@ name: Linux CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
|
# only run when backend source changes
|
||||||
|
# cmd/ is excluded because it contains utilities without tests
|
||||||
|
paths:
|
||||||
|
- "**/*.go"
|
||||||
|
- "!cmd/**"
|
||||||
|
- "go.mod"
|
||||||
|
- "go.sum"
|
||||||
|
- "Makefile"
|
||||||
|
- ".github/workflows/go-ci.yml"
|
||||||
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
|
paths:
|
||||||
|
- "**/*.go"
|
||||||
|
- "!cmd/**"
|
||||||
|
- "go.mod"
|
||||||
|
- "go.sum"
|
||||||
|
- "Makefile"
|
||||||
|
- ".github/workflows/go-ci.yml"
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
run-tests:
|
run-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: '1.23'
|
go-version-file: go.mod
|
||||||
|
|
||||||
# Only run in this linux based runner
|
# Only run in this linux based runner
|
||||||
- name: Check Formatting
|
- name: Check Formatting
|
||||||
run: |
|
run: |
|
||||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||||
gofmt -l . | grep -v 'event/.*_test.go'
|
gofmt -l . | grep -v 'event/.*_test.go'
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
# cache simple-responder to save the build time
|
# cache simple-responder to save the build time
|
||||||
- name: Restore Simple Responder
|
- name: Restore Simple Responder
|
||||||
id: restore-simple-responder
|
id: restore-simple-responder
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
run: make simple-responder
|
run: make simple-responder
|
||||||
|
|
||||||
- name: Save Simple Responder
|
- name: Save Simple Responder
|
||||||
# nothing new to save ... skip this step
|
# nothing new to save ... skip this step
|
||||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
id: save-simple-responder
|
id: save-simple-responder
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
path: ./build
|
path: ./build
|
||||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
run: make test-all
|
run: make test-all
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ name: goreleaser
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- "*"
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
tag:
|
tag:
|
||||||
description: 'Tag version to release (e.g. v144)'
|
description: "Tag version to release (e.g. v144)"
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
@@ -19,35 +19,30 @@ jobs:
|
|||||||
goreleaser:
|
goreleaser:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
-
|
- name: Checkout
|
||||||
name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||||
-
|
- name: Set up Go
|
||||||
name: Set up Go
|
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
-
|
- name: Set up Node.js
|
||||||
name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: '23'
|
node-version: "24"
|
||||||
-
|
- name: Install dependencies and build UI
|
||||||
name: Install dependencies and build UI
|
|
||||||
run: |
|
run: |
|
||||||
cd ui
|
cd ui-svelte
|
||||||
npm ci
|
npm ci
|
||||||
npm run build
|
npm run build
|
||||||
|
|
||||||
-
|
- name: Run GoReleaser
|
||||||
name: Run GoReleaser
|
|
||||||
uses: goreleaser/goreleaser-action@v6
|
uses: goreleaser/goreleaser-action@v6
|
||||||
with:
|
with:
|
||||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||||
distribution: goreleaser
|
distribution: goreleaser
|
||||||
# 'latest', 'nightly', or a semver
|
# 'latest', 'nightly', or a semver
|
||||||
version: '~> v2'
|
version: "~> v2"
|
||||||
args: release --clean
|
args: release --clean
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@@ -76,4 +71,4 @@ jobs:
|
|||||||
"release": {
|
"release": {
|
||||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
name: UI Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- 'ui-svelte/**'
|
||||||
|
- '.github/workflows/ui-tests.yml'
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- 'ui-svelte/**'
|
||||||
|
- '.github/workflows/ui-tests.yml'
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '24'
|
||||||
|
cache: 'npm'
|
||||||
|
cache-dependency-path: ui-svelte/package-lock.json
|
||||||
|
|
||||||
|
- name: Run UI tests
|
||||||
|
run: make test-ui
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
name: Build Unified Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "37 5 * * *"
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
llama_cpp_ref:
|
||||||
|
description: "llama.cpp commit hash, tag, or branch"
|
||||||
|
required: false
|
||||||
|
default: "master"
|
||||||
|
whisper_ref:
|
||||||
|
description: "whisper.cpp commit hash, tag, or branch"
|
||||||
|
required: false
|
||||||
|
default: "master"
|
||||||
|
sd_ref:
|
||||||
|
description: "stable-diffusion.cpp commit hash, tag, or branch"
|
||||||
|
required: false
|
||||||
|
default: "master"
|
||||||
|
ik_llama_ref:
|
||||||
|
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
|
||||||
|
required: false
|
||||||
|
default: "main"
|
||||||
|
llama_swap_version:
|
||||||
|
description: "llama-swap version (e.g. v198, latest, main)"
|
||||||
|
required: false
|
||||||
|
default: "main"
|
||||||
|
build_cuda:
|
||||||
|
description: "Build CUDA image"
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: true
|
||||||
|
build_vulkan:
|
||||||
|
description: "Build Vulkan image"
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: true
|
||||||
|
push_to_ghcr:
|
||||||
|
description: "Push images to ghcr.io"
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
setup:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||||
|
steps:
|
||||||
|
- id: set-matrix
|
||||||
|
run: |
|
||||||
|
backends=()
|
||||||
|
# schedule uses defaults (build both); workflow_dispatch respects inputs
|
||||||
|
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
|
||||||
|
backends+=("cuda")
|
||||||
|
fi
|
||||||
|
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
|
||||||
|
backends+=("vulkan")
|
||||||
|
fi
|
||||||
|
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
|
||||||
|
echo "matrix=$matrix" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: setup
|
||||||
|
if: ${{ needs.setup.outputs.matrix != '[]' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Free up disk space
|
||||||
|
run: |
|
||||||
|
echo "Before cleanup:"
|
||||||
|
df -h
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /opt/ghc
|
||||||
|
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||||
|
sudo docker system prune -af
|
||||||
|
echo "After cleanup:"
|
||||||
|
df -h
|
||||||
|
|
||||||
|
# On GitHub Actions runners, create a fresh builder.
|
||||||
|
# When running locally under act, skip this and reuse the existing
|
||||||
|
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
if: ${{ !env.ACT }}
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to GitHub Container Registry
|
||||||
|
if: ${{ !env.ACT }}
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Build unified Docker image (${{ matrix.backend }})
|
||||||
|
env:
|
||||||
|
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
|
||||||
|
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
|
||||||
|
SD_REF: ${{ inputs.sd_ref || 'master' }}
|
||||||
|
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
|
||||||
|
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
|
||||||
|
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
|
||||||
|
# When running under act, use the local builder that has warm ccache.
|
||||||
|
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
|
||||||
|
# created by setup-buildx-action above.
|
||||||
|
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
|
||||||
|
run: |
|
||||||
|
chmod +x docker/unified/build-image.sh
|
||||||
|
docker/unified/build-image.sh --${{ matrix.backend }}
|
||||||
|
|
||||||
|
- name: Push to GitHub Container Registry
|
||||||
|
if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
|
||||||
|
run: |
|
||||||
|
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
|
||||||
|
DATE_TAG=$(date -u +%Y-%m-%d)
|
||||||
|
|
||||||
|
docker push "${BASE_TAG}"
|
||||||
|
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
|
||||||
|
docker push "${BASE_TAG}-${DATE_TAG}"
|
||||||
|
|
||||||
|
ROOTLESS_TAG="${BASE_TAG}-rootless"
|
||||||
|
docker push "${ROOTLESS_TAG}"
|
||||||
|
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||||
|
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
## Project Description:
|
||||||
|
|
||||||
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
|
## Tech stack
|
||||||
|
|
||||||
|
- golang
|
||||||
|
- typescript, vite and svelt5 for UI (located in ui/)
|
||||||
|
|
||||||
|
## Workflow Tasks
|
||||||
|
|
||||||
|
- when summarizing changes only include details that require further action
|
||||||
|
- just say "Done." when there is no further action
|
||||||
|
- use the github CLI `gh` to create pull requests and work with github
|
||||||
|
- Rules for creating pull requests:
|
||||||
|
- keep them short and focused on changes.
|
||||||
|
- never include a test plan
|
||||||
|
- write the summary using the same style rules as commit message
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||||
|
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||||
|
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||||
|
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||||
|
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||||
|
- Use `make test-ui` after making changes to the UI in ui-svelte/
|
||||||
|
|
||||||
|
### Commit message example format:
|
||||||
|
|
||||||
|
```
|
||||||
|
proxy: add new feature
|
||||||
|
|
||||||
|
Add new feature that implements functionality X and Y.
|
||||||
|
|
||||||
|
- key change 1
|
||||||
|
- key change 2
|
||||||
|
- key change 3
|
||||||
|
|
||||||
|
fixes #123
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Reviews
|
||||||
|
|
||||||
|
- use three levels High, Medium, Low severity
|
||||||
|
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||||
|
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||||
|
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||||
|
- Low severity are nice to have changes and nits
|
||||||
|
- Include a suggestion with each discovered item
|
||||||
|
- Limit your code review to three items with the highest priority first
|
||||||
|
- Double check your discovered items and recommended remediations
|
||||||
@@ -1,43 +1 @@
|
|||||||
# Project: llama-swap
|
@AGENTS.md
|
||||||
|
|
||||||
## Project Description:
|
|
||||||
|
|
||||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
|
||||||
|
|
||||||
## Tech stack
|
|
||||||
|
|
||||||
- golang
|
|
||||||
- typescript, vite and react for UI (ui/)
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
|
||||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
|
||||||
|
|
||||||
## Workflow Tasks
|
|
||||||
|
|
||||||
### Plan Improvements
|
|
||||||
|
|
||||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
|
||||||
|
|
||||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
|
||||||
|
|
||||||
- Identify any inconsistencies.
|
|
||||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
|
||||||
- Plans should have at least these sections:
|
|
||||||
- Title - very short, describes changes
|
|
||||||
- Overview: A more detailed summary of goal and outcomes desired
|
|
||||||
- Design Requirements: Detailed descriptions of what needs to be done
|
|
||||||
- Testing Plan: Tests to be implemented
|
|
||||||
- Checklist: A detailed list of changes to be made
|
|
||||||
|
|
||||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
|
||||||
|
|
||||||
### Implementation of plans
|
|
||||||
|
|
||||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
|
||||||
|
|
||||||
## General Rules
|
|
||||||
|
|
||||||
- when summarizing changes only include details that require further action (action items)
|
|
||||||
- when there are no action items, just say "Done."
|
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ test-all: proxy/ui_dist/placeholder.txt
|
|||||||
go test -race -count=1 ./proxy/...
|
go test -race -count=1 ./proxy/...
|
||||||
|
|
||||||
ui/node_modules:
|
ui/node_modules:
|
||||||
cd ui && npm install
|
cd ui-svelte && npm install
|
||||||
|
|
||||||
# build react UI
|
# build react UI
|
||||||
ui: ui/node_modules
|
ui: ui/node_modules
|
||||||
cd ui && npm run build
|
cd ui-svelte && npm run build
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac: ui
|
mac: ui
|
||||||
@@ -48,9 +48,14 @@ mac: ui
|
|||||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||||
|
|
||||||
# Build Linux binary
|
# Build Linux binary
|
||||||
linux: ui
|
linux: linux-arm64 linux-amd64
|
||||||
@echo "Building Linux binary..."
|
|
||||||
|
linux-amd64: ui
|
||||||
|
@echo "Building Linux AMD64 binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
|
||||||
|
linux-arm64: ui
|
||||||
|
@echo "Building Linux ARM64 binary..."
|
||||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||||
|
|
||||||
# Build Windows binary
|
# Build Windows binary
|
||||||
@@ -92,5 +97,9 @@ wol-proxy: $(BUILD_DIR)
|
|||||||
@echo "Building wol-proxy"
|
@echo "Building wol-proxy"
|
||||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||||
|
|
||||||
|
test-ui:
|
||||||
|
cd ui-svelte && npm ci && npm run check && npm test
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
.PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
|
||||||
|
.PHONE: linux linux-arm64 linux-amd64
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||

|

|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
||||||
|
|
||||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||||
|
|
||||||
@@ -13,18 +13,29 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
|
|
||||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||||
- ✅ On-demand model switching
|
- ✅ On-demand model switching
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||||
- future proof, upgrade your inference servers at any time.
|
- future proof, upgrade your inference servers at any time.
|
||||||
- ✅ OpenAI API supported endpoints:
|
- ✅ OpenAI API supported endpoints:
|
||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
|
- `v1/responses`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
|
- `v1/audio/voices`
|
||||||
|
- `v1/images/generations`
|
||||||
|
- `v1/images/edits`
|
||||||
|
- ✅ Anthropic API supported endpoints:
|
||||||
|
- `v1/messages`
|
||||||
|
- `v1/messages/count_tokens`
|
||||||
- ✅ llama-server (llama.cpp) supported endpoints
|
- ✅ llama-server (llama.cpp) supported endpoints
|
||||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||||
- `/infill` - for code infilling
|
- `/infill` - for code infilling
|
||||||
- `/completion` - for completion endpoint
|
- `/completion` - for completion endpoint
|
||||||
|
- ✅ SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
|
||||||
|
- `/sdapi/v1/txt2img`
|
||||||
|
- `/sdapi/v1/img2img`
|
||||||
|
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
|
||||||
- ✅ llama-swap API
|
- ✅ llama-swap API
|
||||||
- `/ui` - web UI
|
- `/ui` - web UI
|
||||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
@@ -32,22 +43,34 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
|||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- `/log` - remote log monitoring
|
- `/log` - remote log monitoring
|
||||||
- `/health` - just returns "OK"
|
- `/health` - just returns "OK"
|
||||||
|
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||||
- ✅ Customizable
|
- ✅ Customizable
|
||||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
|
||||||
- Automatic unloading of models after timeout by setting a `ttl`
|
- Automatic unloading of models after timeout by setting a `ttl`
|
||||||
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||||
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||||
|
|
||||||
### Web UI
|
### Web UI
|
||||||
|
|
||||||
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
||||||
|
|
||||||
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
||||||
|
|
||||||
|
View detailed token metrics:
|
||||||
|
|
||||||
The Activity Page shows recent requests:
|
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
|
||||||
|
|
||||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
Inspect request and responses:
|
||||||
|
|
||||||
|
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
||||||
|
|
||||||
|
Manually load and unload models:
|
||||||
|
|
||||||
|
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
||||||
|
|
||||||
|
Real time log streaming:
|
||||||
|
|
||||||
|
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -61,7 +84,8 @@ llama-swap can be installed in multiple ways
|
|||||||
|
|
||||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc).
|
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||||
|
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
@@ -71,6 +95,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
|||||||
-v /path/to/models:/models \
|
-v /path/to/models:/models \
|
||||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
ghcr.io/mostlygeek/llama-swap:cuda
|
ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
|
|
||||||
|
# configuration hot reload supported with a
|
||||||
|
# directory volume mount
|
||||||
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
|
-v /path/to/models:/models \
|
||||||
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
|
-v /path/to/config:/config \
|
||||||
|
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -89,6 +121,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
|
|||||||
# tagged llama-swap, platform and llama-server version images
|
# tagged llama-swap, platform and llama-server version images
|
||||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||||
|
|
||||||
|
# non-root cuda
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -145,7 +180,7 @@ That's all you need to get started:
|
|||||||
Almost all configuration settings are optional and can be added one step at a time:
|
Almost all configuration settings are optional and can be added one step at a time:
|
||||||
|
|
||||||
- Advanced features
|
- Advanced features
|
||||||
- `groups` to run multiple models at once
|
- `matrix` to run concurrent models with a custom swap logic DSL
|
||||||
- `hooks` to run things on startup
|
- `hooks` to run things on startup
|
||||||
- `macros` reusable snippets
|
- `macros` reusable snippets
|
||||||
- Model customization
|
- Model customization
|
||||||
@@ -163,7 +198,7 @@ See the [configuration documentation](docs/configuration.md) for all options.
|
|||||||
|
|
||||||
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
||||||
|
|
||||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, using a `matrix` allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
## Reverse Proxy Configuration (nginx)
|
## Reverse Proxy Configuration (nginx)
|
||||||
|
|
||||||
@@ -191,23 +226,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
|
|||||||
|
|
||||||
## Monitoring Logs on the CLI
|
## Monitoring Logs on the CLI
|
||||||
|
|
||||||
```shell
|
```sh
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
$ curl http://host/logs
|
||||||
|
|
||||||
# streams combined logs
|
# streams combined logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns http://host/logs/stream
|
||||||
|
|
||||||
# just llama-swap's logs
|
# stream llama-swap's proxy status logs
|
||||||
curl -Ns 'http://host/logs/stream/proxy'
|
curl -Ns http://host/logs/stream/proxy
|
||||||
|
|
||||||
# just upstream's logs
|
# stream logs from upstream processes that llama-swap loads
|
||||||
curl -Ns 'http://host/logs/stream/upstream'
|
curl -Ns http://host/logs/stream/upstream
|
||||||
|
|
||||||
|
# stream logs only from a specific model
|
||||||
|
curl -Ns http://host/logs/stream/{model_id}
|
||||||
|
|
||||||
# stream and filter logs with linux pipes
|
# stream and filter logs with linux pipes
|
||||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
# skips history and just streams new log entries
|
# appending ?no-history will disable sending buffered history first
|
||||||
curl -Ns 'http://host/logs/stream?no-history'
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||||
|
|
||||||
|
## Current Issues
|
||||||
|
|
||||||
|
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||||
|
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||||
|
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||||
|
4. Linked list structure has poor cache locality and pointer overhead
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### New CircularBuffer Type
|
||||||
|
|
||||||
|
Create a simple circular byte buffer with:
|
||||||
|
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||||
|
- `head` and `size` integers to track write position and data length
|
||||||
|
- No per-write allocations
|
||||||
|
|
||||||
|
### API Requirements
|
||||||
|
|
||||||
|
The new buffer must support:
|
||||||
|
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||||
|
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||||
|
|
||||||
|
### Implementation Details
|
||||||
|
|
||||||
|
```go
|
||||||
|
type circularBuffer struct {
|
||||||
|
data []byte // pre-allocated capacity
|
||||||
|
head int // next write position
|
||||||
|
size int // current number of bytes stored (0 to cap)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Write logic:**
|
||||||
|
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||||
|
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||||
|
- Update `head` and `size` accordingly
|
||||||
|
- Data is copied into the internal buffer (not stored by reference)
|
||||||
|
|
||||||
|
**GetHistory logic:**
|
||||||
|
- Calculate start position: `(head - size + cap) % cap`
|
||||||
|
- If not wrapped: single slice copy
|
||||||
|
- If wrapped: two copies (end of buffer + beginning)
|
||||||
|
- Returns a **new slice** (copy), not a view into internal buffer
|
||||||
|
|
||||||
|
### Immutability Guarantees (must preserve)
|
||||||
|
|
||||||
|
Per existing tests:
|
||||||
|
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||||
|
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||||
|
|
||||||
|
## Files to Modify
|
||||||
|
|
||||||
|
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||||
|
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||||
|
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||||
|
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||||
|
|
||||||
|
Add new tests:
|
||||||
|
- Test buffer wrap-around behavior
|
||||||
|
- Test large writes that exceed buffer capacity
|
||||||
|
- Test exact capacity boundary conditions
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||||
|
- [ ] Implement `Write()` method for circular buffer
|
||||||
|
- [ ] Implement `GetHistory()` method for circular buffer
|
||||||
|
- [ ] Update `LogMonitor` struct to use new buffer
|
||||||
|
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||||
|
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||||
|
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||||
|
- [ ] Remove `"container/ring"` import
|
||||||
|
- [ ] Run `make test-dev` to verify existing tests pass
|
||||||
|
- [ ] Add wrap-around test case
|
||||||
|
- [ ] Run `make test-all` for final validation
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
# Improve Testability (#655)
|
||||||
|
|
||||||
|
## Current Pain Points
|
||||||
|
|
||||||
|
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
|
||||||
|
|
||||||
|
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
|
||||||
|
|
||||||
|
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
|
||||||
|
|
||||||
|
## Stages
|
||||||
|
|
||||||
|
### Stage 1: YAML-based test config helper
|
||||||
|
|
||||||
|
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
|
||||||
|
|
||||||
|
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
|
||||||
|
|
||||||
|
Create a test helper in `proxy/helpers_test.go`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// testConfigFromYAML substitutes simple-responder paths and loads through
|
||||||
|
// the real config pipeline (env vars, macros, port assignment, etc.)
|
||||||
|
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||||
|
t.Helper()
|
||||||
|
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||||
|
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests would then look like:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
|
config := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||||
|
model2:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
|
||||||
|
`)
|
||||||
|
proxy := New(config)
|
||||||
|
// ... same assertions
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
|
||||||
|
|
||||||
|
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
|
||||||
|
|
||||||
|
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
|
||||||
|
|
||||||
|
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
|
||||||
|
|
||||||
|
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
|
||||||
|
|
||||||
|
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
|
||||||
|
|
||||||
|
**2a. Add testHandler to Process:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// In Process struct (process.go):
|
||||||
|
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
In `Process.Start()`, skip subprocess + health check when handler is set:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func (p *Process) start() error {
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.setState(StateReady)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// existing subprocess logic...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In `Process.ProxyRequest()`, delegate directly to the handler:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Before the reverseProxy.ServeHTTP call:
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.testHandler.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**2b. Test helper to create the handler:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// newTestHandler returns an http.Handler that mimics llama.cpp's API
|
||||||
|
// (same endpoints as simple-responder).
|
||||||
|
func newTestHandler(respond string) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||||
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||||
|
// ... other endpoints
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests for routing/auth/CORS/streaming then become:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestProxyManager_AuthRequired(t *testing.T) {
|
||||||
|
handler := newTestHandler("model1")
|
||||||
|
|
||||||
|
config := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
requiredAPIKeys: [test-key]
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||||
|
`)
|
||||||
|
pm := NewProxyManager(config)
|
||||||
|
// inject handler — skips subprocess, health check, port allocation
|
||||||
|
pm.processGroups["model1"].process.testHandler = handler
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
|
||||||
|
|
||||||
|
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
|
||||||
|
|
||||||
|
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
|
||||||
|
|
||||||
|
### Stage 3: Migrate tests incrementally
|
||||||
|
|
||||||
|
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
|
||||||
|
|
||||||
|
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
|
||||||
|
|
||||||
|
Priority order:
|
||||||
|
1. `proxymanager_test.go` routing tests (highest count, most repetition)
|
||||||
|
2. `peerproxy_test.go` (straightforward, all HTTP routing)
|
||||||
|
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
|
||||||
|
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
|
||||||
|
|
||||||
|
Tests that **must keep simple-responder:**
|
||||||
|
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
|
||||||
|
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
|
||||||
|
|
||||||
|
**Scope:** ~60-70% of tests can drop simple-responder.
|
||||||
|
|
||||||
|
### Stage 4 (optional): Process interface for ProcessGroup
|
||||||
|
|
||||||
|
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
|
||||||
|
|
||||||
|
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ProcessController interface {
|
||||||
|
Start() error
|
||||||
|
Stop(StopStrategy)
|
||||||
|
ProxyRequest(http.ResponseWriter, *http.Request) error
|
||||||
|
CurrentState() ProcessState
|
||||||
|
ID() string
|
||||||
|
SetState(ProcessState) // for test setup
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This requires:
|
||||||
|
- Extracting the interface
|
||||||
|
- A `MockProcess` implementation
|
||||||
|
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
|
||||||
|
|
||||||
|
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
|
||||||
|
|
||||||
|
## Effort/Impact Summary
|
||||||
|
|
||||||
|
| Stage | Effort | Impact | Risk |
|
||||||
|
|-------|--------|--------|------|
|
||||||
|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
|
||||||
|
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
|
||||||
|
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
|
||||||
|
| 4. Process interface | High | Pure unit tests possible | Medium |
|
||||||
|
|
||||||
|
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
|
||||||
@@ -210,6 +210,11 @@ func main() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||||
|
model := c.Query("model")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||||
|
})
|
||||||
|
|
||||||
r.GET("/slow-respond", func(c *gin.Context) {
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
echo := c.Query("echo")
|
echo := c.Query("echo")
|
||||||
delay := c.Query("delay")
|
delay := c.Query("delay")
|
||||||
@@ -269,6 +274,43 @@ func main() {
|
|||||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// SD API endpoints
|
||||||
|
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Request.Body.Close()
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"model": modelName,
|
||||||
|
"images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Request.Body.Close()
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"model": modelName,
|
||||||
|
"images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"loras": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
address := "127.0.0.1:" + *port // Address with the specified port
|
address := "127.0.0.1:" + *port // Address with the specified port
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
|||||||
@@ -39,6 +39,49 @@
|
|||||||
},
|
},
|
||||||
"default": {},
|
"default": {},
|
||||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||||
|
},
|
||||||
|
"timeouts": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"connect": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 30,
|
||||||
|
"description": "TCP connection timeout in seconds. Set to 0 to disable."
|
||||||
|
},
|
||||||
|
"keepalive": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 30,
|
||||||
|
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
|
||||||
|
},
|
||||||
|
"responseHeader": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Time to wait for response headers in seconds. Set to 0 to disable."
|
||||||
|
},
|
||||||
|
"tlsHandshake": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 10,
|
||||||
|
"description": "TLS handshake timeout in seconds. Set to 0 to disable."
|
||||||
|
},
|
||||||
|
"expectContinue": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 1,
|
||||||
|
"description": "Expect-Continue timeout in seconds. Set to 0 to disable."
|
||||||
|
},
|
||||||
|
"idleConn": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 90,
|
||||||
|
"description": "Idle connection timeout in seconds. Set to 0 to disable."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"description": "Timeout settings for proxy connections."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -48,6 +91,12 @@
|
|||||||
"default": 120,
|
"default": 120,
|
||||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||||
},
|
},
|
||||||
|
"globalTTL": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
|
||||||
|
},
|
||||||
"logLevel": {
|
"logLevel": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
@@ -87,6 +136,12 @@
|
|||||||
"default": 1000,
|
"default": 1000,
|
||||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||||
},
|
},
|
||||||
|
"captureBuffer": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 5,
|
||||||
|
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||||
|
},
|
||||||
"startPort": {
|
"startPort": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"default": 5800,
|
"default": 5800,
|
||||||
@@ -171,9 +226,9 @@
|
|||||||
},
|
},
|
||||||
"ttl": {
|
"ttl": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 0,
|
"minimum": -1,
|
||||||
"default": 0,
|
"default": -1,
|
||||||
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
|
||||||
},
|
},
|
||||||
"useModelName": {
|
"useModelName": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -188,11 +243,26 @@
|
|||||||
"default": "",
|
"default": "",
|
||||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||||
|
},
|
||||||
|
"setParams": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||||
|
},
|
||||||
|
"setParamsByID": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true
|
||||||
|
},
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"default": {},
|
"default": {},
|
||||||
"description": "Dictionary of filter settings. Only stripParams is supported."
|
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -214,6 +284,9 @@
|
|||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"default": false,
|
"default": false,
|
||||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||||
|
},
|
||||||
|
"timeouts": {
|
||||||
|
"$ref": "#/definitions/timeouts"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -252,6 +325,44 @@
|
|||||||
},
|
},
|
||||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||||
},
|
},
|
||||||
|
"matrix": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||||
|
"required": [
|
||||||
|
"vars",
|
||||||
|
"sets"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"vars": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"propertyNames": {
|
||||||
|
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"evict_costs": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sets": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||||
|
"minProperties": 1,
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
},
|
||||||
"hooks": {
|
"hooks": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -273,6 +384,137 @@
|
|||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||||
|
},
|
||||||
|
"logToStdout": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"proxy",
|
||||||
|
"upstream",
|
||||||
|
"both",
|
||||||
|
"none"
|
||||||
|
],
|
||||||
|
"default": "proxy",
|
||||||
|
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||||
|
},
|
||||||
|
"apiKeys": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
},
|
||||||
|
"default": [],
|
||||||
|
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||||
|
},
|
||||||
|
"peers": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"proxy",
|
||||||
|
"models"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"proxy": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uri",
|
||||||
|
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||||
|
},
|
||||||
|
"apiKey": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
},
|
||||||
|
"description": "A list of models served by the peer."
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stripParams": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||||
|
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||||
|
},
|
||||||
|
"setParams": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||||
|
},
|
||||||
|
"timeouts": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"connect": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 30,
|
||||||
|
"description": "TCP connection timeout in seconds."
|
||||||
|
},
|
||||||
|
"keepalive": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 30,
|
||||||
|
"description": "TCP keepalive connection timeout in seconds."
|
||||||
|
},
|
||||||
|
"responseHeader": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Time to wait for response headers in seconds."
|
||||||
|
},
|
||||||
|
"tlsHandshake": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 10,
|
||||||
|
"description": "TLS handshake timeout in seconds."
|
||||||
|
},
|
||||||
|
"idleConn": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 90,
|
||||||
|
"description": "Idle connection timeout in seconds."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"description": "Timeout settings for proxy connections to this peer."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"default": {},
|
||||||
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
"allOf": [
|
||||||
|
{
|
||||||
|
"if": {
|
||||||
|
"required": ["groups"]
|
||||||
|
},
|
||||||
|
"then": {
|
||||||
|
"not": {
|
||||||
|
"required": ["matrix"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"if": {
|
||||||
|
"required": ["matrix"]
|
||||||
|
},
|
||||||
|
"then": {
|
||||||
|
"not": {
|
||||||
|
"required": ["groups"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -34,12 +34,27 @@ logLevel: info
|
|||||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||||
logTimeFormat: ""
|
logTimeFormat: ""
|
||||||
|
|
||||||
|
# logToStdout: controls what is logged to stdout
|
||||||
|
# - optional, default: "proxy"
|
||||||
|
# - valid values:
|
||||||
|
# - "proxy": logs generated by llama-swap when swapping models,
|
||||||
|
# handling requests, etc.
|
||||||
|
# - "upstream": a copy of an upstream processes stdout logs
|
||||||
|
# - "both": both the proxy and upstream logs interleaved together
|
||||||
|
# - "none": no logs are ever written to stdout
|
||||||
|
logToStdout: "proxy"
|
||||||
|
|
||||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
# - optional, default: 1000
|
# - optional, default: 1000
|
||||||
# - controls how many metrics are stored in memory before older ones are discarded
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
# - useful for limiting memory usage when processing large volumes of metrics
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
metricsMaxInMemory: 1000
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||||
|
# - optional, default: 10
|
||||||
|
# - set to 0 to disable
|
||||||
|
captureBuffer: 15
|
||||||
|
|
||||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
# - optional, default: 5800
|
# - optional, default: 5800
|
||||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
@@ -60,6 +75,11 @@ sendLoadingState: true
|
|||||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
includeAliasesInList: false
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# globalTTL: the default TTL in seconds before unloading a model
|
||||||
|
# - optional, default: 0 (never automatically unload)
|
||||||
|
# - must be >= 0
|
||||||
|
globalTTL: 0
|
||||||
|
|
||||||
# macros: a dictionary of string substitutions
|
# macros: a dictionary of string substitutions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros are reusable snippets
|
# - macros are reusable snippets
|
||||||
@@ -70,6 +90,9 @@ includeAliasesInList: false
|
|||||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
# - macro values can be numbers, bools, or strings
|
# - macro values can be numbers, bools, or strings
|
||||||
# - macros can contain other macros, but they must be defined before they are used
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
|
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||||
|
# - env macros are substituted first, before regular macros
|
||||||
|
# - if the env var is not set, config loading will fail with an error
|
||||||
macros:
|
macros:
|
||||||
# Example of a multi-line macro
|
# Example of a multi-line macro
|
||||||
"latest-llama": >
|
"latest-llama": >
|
||||||
@@ -82,6 +105,24 @@ macros:
|
|||||||
# but they must be previously declared.
|
# but they must be previously declared.
|
||||||
"default_args": "--ctx-size ${default_ctx}"
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
|
# Example of environment variable macros
|
||||||
|
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||||
|
# - useful for paths, secrets, or machine-specific configuration
|
||||||
|
"models_dir": "${env.HOME}/models"
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
|
||||||
|
# use environment variable macros to keep secrets out of the config
|
||||||
|
- "${env.API_KEY_1}"
|
||||||
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
@@ -90,7 +131,7 @@ macros:
|
|||||||
# - below are examples of the all the settings a model can have
|
# - below are examples of the all the settings a model can have
|
||||||
models:
|
models:
|
||||||
# keys are the model names used in API requests
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"gpt-oss-120b":
|
||||||
# macros: a dictionary of string substitutions specific to this model
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros defined here override macros defined in the global macros section
|
# - macros defined here override macros defined in the global macros section
|
||||||
@@ -107,7 +148,7 @@ models:
|
|||||||
cmd: |
|
cmd: |
|
||||||
# ${latest-llama} is a macro that is defined above
|
# ${latest-llama} is a macro that is defined above
|
||||||
${latest-llama}
|
${latest-llama}
|
||||||
--model path/to/llama-8B-Q4_K_M.gguf
|
--model path/to/gpt-oss-120B.gguf
|
||||||
--ctx-size ${default_ctx}
|
--ctx-size ${default_ctx}
|
||||||
--temperature ${temp}
|
--temperature ${temp}
|
||||||
|
|
||||||
@@ -115,13 +156,13 @@ models:
|
|||||||
# - optional, default: empty string
|
# - optional, default: empty string
|
||||||
# - if set, it will be used in the v1/models API response
|
# - if set, it will be used in the v1/models API response
|
||||||
# - if not set, it will be omitted in the JSON model record
|
# - if not set, it will be omitted in the JSON model record
|
||||||
name: "llama 3.1 8B"
|
name: "gpt-oss 120B"
|
||||||
|
|
||||||
# description: a description for the model
|
# description: a description for the model
|
||||||
# - optional, default: empty string
|
# - optional, default: empty string
|
||||||
# - if set, it will be used in the v1/models API response
|
# - if set, it will be used in the v1/models API response
|
||||||
# - if not set, it will be omitted in the JSON model record
|
# - if not set, it will be omitted in the JSON model record
|
||||||
description: "A small but capable model used for quick testing"
|
description: "A thinking model from OpenAI"
|
||||||
|
|
||||||
# env: define an array of environment variables to inject into cmd's environment
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
# - optional, default: empty array
|
# - optional, default: empty array
|
||||||
@@ -136,14 +177,6 @@ models:
|
|||||||
# - if you use a custom port in cmd this *must* be set
|
# - if you use a custom port in cmd this *must* be set
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases: alternative model names that this model configuration is used for
|
|
||||||
# - optional, default: empty array
|
|
||||||
# - aliases must be unique globally
|
|
||||||
# - useful for impersonating a specific model
|
|
||||||
aliases:
|
|
||||||
- "gpt-4o-mini"
|
|
||||||
- "gpt-3.5-turbo"
|
|
||||||
|
|
||||||
# checkEndpoint: URL path to check if the server is ready
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
# - optional, default: /health
|
# - optional, default: /health
|
||||||
# - endpoint is expected to return an HTTP 200 response
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
@@ -152,8 +185,10 @@ models:
|
|||||||
checkEndpoint: /custom-endpoint
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# ttl: automatically unload the model after ttl seconds
|
# ttl: automatically unload the model after ttl seconds
|
||||||
# - optional, default: 0
|
# - optional, default: -1 (use global default)
|
||||||
# - ttl values must be a value greater than 0
|
# - ttl values must be a value greater than or equal to 0
|
||||||
|
# - a ttl of -1 will use the global TTL value as the default
|
||||||
|
# - a ttl of 0 will mean never unload
|
||||||
# - a value of 0 disables automatic unloading of the model
|
# - a value of 0 disables automatic unloading of the model
|
||||||
ttl: 60
|
ttl: 60
|
||||||
|
|
||||||
@@ -161,11 +196,11 @@ models:
|
|||||||
# - optional, default: ""
|
# - optional, default: ""
|
||||||
# - useful for when the upstream server expects a specific model name that
|
# - useful for when the upstream server expects a specific model name that
|
||||||
# is different from the model's ID
|
# is different from the model's ID
|
||||||
useModelName: "qwen:qwq"
|
useModelName: "openai/gpt-oss-120B"
|
||||||
|
|
||||||
# filters: a dictionary of filter settings
|
# filters: a dictionary of filter settings
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - only stripParams is currently supported
|
# - same capabilities as peer filters (stripParams, setParams)
|
||||||
filters:
|
filters:
|
||||||
# stripParams: a comma separated list of parameters to remove from the request
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
# - optional, default: ""
|
# - optional, default: ""
|
||||||
@@ -175,6 +210,43 @@ models:
|
|||||||
# - recommended to stick to sampling parameters
|
# - recommended to stick to sampling parameters
|
||||||
stripParams: "temperature, top_p, top_k"
|
stripParams: "temperature, top_p, top_k"
|
||||||
|
|
||||||
|
# setParams: a dictionary of parameters to set/override in requests
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - useful for enforcing specific parameter values
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
# - always runs for the model
|
||||||
|
setParams:
|
||||||
|
# Example: enforce specific sampling parameters
|
||||||
|
temperature: 0.7
|
||||||
|
top_p: 0.9
|
||||||
|
|
||||||
|
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - combine with aliases to create variant behaviour without reloading the model
|
||||||
|
# - parameters are set in the request body JSON
|
||||||
|
# - run after setParams so it will override any settings
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
# - model aliases will be automatically created for each key
|
||||||
|
setParamsByID:
|
||||||
|
"${MODEL_ID}":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: medium
|
||||||
|
"${MODEL_ID}:high":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: high
|
||||||
|
"${MODEL_ID}:low":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: low
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
|
aliases:
|
||||||
|
- "gpt-4o-mini"
|
||||||
|
|
||||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - while metadata can contains complex types it is recommended to keep it simple
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
@@ -212,6 +284,22 @@ models:
|
|||||||
# - optional, default: undefined (use global setting)
|
# - optional, default: undefined (use global setting)
|
||||||
sendLoadingState: false
|
sendLoadingState: false
|
||||||
|
|
||||||
|
# timeouts: configure proxy connection timeouts for this model
|
||||||
|
# - optional, defaults shown below
|
||||||
|
# - useful for models running on slower hardware that need longer timeouts
|
||||||
|
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
|
||||||
|
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
|
||||||
|
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
|
||||||
|
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
|
||||||
|
# - idleConn: idle connection timeout in seconds, default: 90 seconds
|
||||||
|
# - set any value to 0 to disable that timeout (not recommended)
|
||||||
|
timeouts:
|
||||||
|
connect: 30
|
||||||
|
keepalive: 0
|
||||||
|
responseHeader: 60
|
||||||
|
tlsHandshake: 10
|
||||||
|
idleConn: 90
|
||||||
|
|
||||||
# Unlisted model example:
|
# Unlisted model example:
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
# unlisted: boolean, true or false
|
# unlisted: boolean, true or false
|
||||||
@@ -243,68 +331,83 @@ models:
|
|||||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
cmdStop: docker stop ${MODEL_ID}
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# groups: a dictionary of group settings
|
# =============================================================================
|
||||||
# - optional, default: empty dictionary
|
# matrix: run concurrent models with a solver-based swap DSL
|
||||||
# - provides advanced controls over model swapping behaviour
|
# =============================================================================
|
||||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
|
||||||
# - model IDs must be defined in the Models section
|
|
||||||
# - a model can only be a member of one group
|
|
||||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
|
||||||
# - see issue #109 for details
|
|
||||||
#
|
#
|
||||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
# Note:
|
||||||
groups:
|
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||||
# to run a time across the whole llama-swap instance
|
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||||
"group1":
|
#
|
||||||
# swap: controls the model swapping behaviour in within the group
|
# The matrix declares valid combinations of models that can run concurrently.
|
||||||
# - optional, default: true
|
# When a model is requested, the solver finds the cheapest way to make it
|
||||||
# - true : only one model is allowed to run at a time
|
# available by evicting as few (and least costly) running models as possible.
|
||||||
# - false: all models can run together, no swapping
|
#
|
||||||
swap: true
|
# Solver behavior:
|
||||||
|
# 1. Request arrives for model X
|
||||||
|
# 2. If X is already running, forward immediately. Done.
|
||||||
|
# 3. Find all sets containing X
|
||||||
|
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||||
|
# every running model NOT in that set
|
||||||
|
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||||
|
# 6. Evict what needs to stop. Start X. Forward request.
|
||||||
|
#
|
||||||
|
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||||
|
# Only the requested model is started — others are not preloaded.
|
||||||
|
#
|
||||||
|
# A model not appearing in any set can only run alone.
|
||||||
|
#
|
||||||
|
matrix:
|
||||||
|
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||||
|
# - required for sets and evict_costs settings
|
||||||
|
# - each entry is a short name to a real model ID. Do not use an alias
|
||||||
|
# - used to keep set DSL logic short and easier to read
|
||||||
|
# - sets and evict_costs only use identifiers defined in vars
|
||||||
|
vars:
|
||||||
|
g: gemma-model
|
||||||
|
q: qwen-model
|
||||||
|
m: mistral-model
|
||||||
|
v: voxtral-model
|
||||||
|
e: reranker-model
|
||||||
|
L: llama-70B
|
||||||
|
sd: stable-diffusion
|
||||||
|
|
||||||
# exclusive: controls how the group affects other groups
|
# evict_costs: relative cost of losing a running model (default: 1)
|
||||||
# - optional, default: true
|
evict_costs:
|
||||||
# - true: causes all other groups to unload when this group runs a model
|
v: 50 # vllm backend, slow cold start
|
||||||
# - false: does not affect other groups
|
L: 30 # 70B weights, slow to load
|
||||||
exclusive: true
|
|
||||||
|
|
||||||
# members references the models defined above
|
# sets: named sets of concurrent model combinations
|
||||||
# required
|
# Values are DSL strings with operators:
|
||||||
members:
|
# & AND (models run together)
|
||||||
- "llama"
|
# | OR (alternatives)
|
||||||
- "qwen-unlisted"
|
# () grouping
|
||||||
|
# +ref inline another set's expression
|
||||||
|
#
|
||||||
|
# Expansion examples:
|
||||||
|
# "L" → [L]
|
||||||
|
# "a & b" → [a, b]
|
||||||
|
# "a | b" → [a], [b]
|
||||||
|
# "(a | b) & c" → [a, c], [b, c]
|
||||||
|
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||||
|
# "+llms & v" → expands llms inline, then applies & v
|
||||||
|
sets:
|
||||||
|
# LLM + TTS: switching between g/q/m won't evict v
|
||||||
|
# expands to: [g,v], [q,v], [m,v]
|
||||||
|
standard: "(g | q | m) & v"
|
||||||
|
|
||||||
# Example:
|
# LLM + TTS + reranker
|
||||||
# - in group2 all models can run at the same time
|
# expands to: [g,v,e], [q,v,e]
|
||||||
# - when a different group is loaded it causes all running models in this group to unload
|
with_rerank: "(g | q) & v & e"
|
||||||
"group2":
|
|
||||||
swap: false
|
|
||||||
|
|
||||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
# LLM + image generation, no TTS
|
||||||
# - the models in group2 will be loaded but will not unload any other groups
|
# expands to: [g,sd], [q,sd]
|
||||||
exclusive: false
|
creative: "(g | q) & sd"
|
||||||
members:
|
|
||||||
- "docker-llama"
|
|
||||||
- "modelA"
|
|
||||||
- "modelB"
|
|
||||||
|
|
||||||
# Example:
|
# 70B model uses all GPUs, can only run alone
|
||||||
# - a persistent group, prevents other groups from unloading it
|
# expands to: [L]
|
||||||
"forever":
|
full: "L"
|
||||||
# persistent: prevents over groups from unloading the models in this group
|
|
||||||
# - optional, default: false
|
|
||||||
# - does not affect individual model behaviour
|
|
||||||
persistent: true
|
|
||||||
|
|
||||||
# set swap/exclusive to false to prevent swapping inside the group
|
|
||||||
# and the unloading of other groups
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "forever-modelA"
|
|
||||||
- "forever-modelB"
|
|
||||||
- "forever-modelc"
|
|
||||||
|
|
||||||
# hooks: a dictionary of event triggers and actions
|
# hooks: a dictionary of event triggers and actions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
@@ -321,3 +424,67 @@ hooks:
|
|||||||
# otherwise models will be loaded and swapped out
|
# otherwise models will be loaded and swapped out
|
||||||
preload:
|
preload:
|
||||||
- "llama"
|
- "llama"
|
||||||
|
|
||||||
|
# peers: a dictionary of remote peers and models they provide
|
||||||
|
# - optional, default empty dictionary
|
||||||
|
# - peers can be another llama-swap
|
||||||
|
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||||
|
peers:
|
||||||
|
# keys is the peer'd ID
|
||||||
|
llama-swap-peer:
|
||||||
|
# proxy: a valid base URL to proxy requests to
|
||||||
|
# - required
|
||||||
|
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
# models: a list of models served by the peer
|
||||||
|
# - required
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
- embeddings/model_c
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
# apiKey: a string key to be injected into the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - if blank, no key will be added to the request
|
||||||
|
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||||
|
# - can be a string or a macro
|
||||||
|
apiKey: ${env.OPENROUTER_API_KEY}
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
- qwen/qwen3-235b-a22b-2507
|
||||||
|
- deepseek/deepseek-v3.2
|
||||||
|
- z-ai/glm-4.7
|
||||||
|
- moonshotai/kimi-k2-0905
|
||||||
|
- minimax/minimax-m2.1
|
||||||
|
# timeouts: configure proxy connection timeouts for this peer
|
||||||
|
# - optional, defaults shown below
|
||||||
|
# - useful when the peer runs on slower hardware
|
||||||
|
# - set any value to 0 to disable that timeout (not recommended)
|
||||||
|
timeouts:
|
||||||
|
connect: 30
|
||||||
|
keepalive: 30
|
||||||
|
responseHeader: 60
|
||||||
|
tlsHandshake: 10
|
||||||
|
idleConn: 90
|
||||||
|
|
||||||
|
# filters: a dictionary of filter settings for peer requests
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - same capabilities as model filters (stripParams, setParams)
|
||||||
|
filters:
|
||||||
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for removing parameters that the peer doesn't support
|
||||||
|
# - the `model` parameter can never be removed
|
||||||
|
stripParams: "temperature, top_p"
|
||||||
|
|
||||||
|
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - useful for injecting provider-specific settings like data retention policies
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
setParams:
|
||||||
|
# Example: enforce zero-data-retention for OpenRouter
|
||||||
|
provider:
|
||||||
|
data_collection: "deny"
|
||||||
|
zdr: true
|
||||||
|
|||||||
@@ -1,28 +1,50 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
cd $(dirname "$0")
|
cd $(dirname "$0")
|
||||||
|
|
||||||
|
# use this to test locally, example:
|
||||||
|
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||||
|
# you need read:package scope on the token. Generate a personal access token with
|
||||||
|
# the scopes: gist, read:org, repo, write:packages
|
||||||
|
# then: gh auth login (and copy/paste the new token)
|
||||||
|
|
||||||
|
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||||
|
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||||
|
|
||||||
|
log_debug() {
|
||||||
|
if [ "$LOG_DEBUG" = "1" ]; then
|
||||||
|
echo "[DEBUG] $*"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo "[INFO] $*"
|
||||||
|
}
|
||||||
|
|
||||||
ARCH=$1
|
ARCH=$1
|
||||||
PUSH_IMAGES=${2:-false}
|
PUSH_IMAGES=${2:-false}
|
||||||
|
|
||||||
# List of allowed architectures
|
# List of allowed architectures
|
||||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
|
||||||
|
|
||||||
# Check if ARCH is in the allowed list
|
# Check if ARCH is in the allowed list
|
||||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if GITHUB_TOKEN is set and not empty
|
# Check if GITHUB_TOKEN is set and not empty
|
||||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||||
# variable, this permits testing with forked llama.cpp repositories
|
# variable, this permits testing with forked llama.cpp repositories
|
||||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||||
|
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||||
|
|
||||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||||
# to enable easy container builds on forked repos
|
# to enable easy container builds on forked repos
|
||||||
@@ -32,25 +54,76 @@ LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
|||||||
# have to strip out the 'v' due to .tar.gz file naming
|
# have to strip out the 'v' due to .tar.gz file naming
|
||||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||||
|
|
||||||
|
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||||
|
# Handles pagination to search beyond the first 100 results
|
||||||
|
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||||
|
# Returns: the version number extracted from the tag
|
||||||
|
fetch_llama_tag() {
|
||||||
|
local tag_prefix=$1
|
||||||
|
local page=1
|
||||||
|
local per_page=100
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||||
|
|
||||||
|
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||||
|
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||||
|
|
||||||
|
# Check for API errors
|
||||||
|
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||||
|
local error_msg=$(echo "$response" | jq -r '.message')
|
||||||
|
log_info "GitHub API error: $error_msg"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if response is empty array (no more pages)
|
||||||
|
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||||
|
log_debug "No more pages (empty response)"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract matching tag from this page
|
||||||
|
local found_tag=$(echo "$response" | jq -r \
|
||||||
|
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||||
|
| sort -r | head -n1)
|
||||||
|
|
||||||
|
if [ -n "$found_tag" ]; then
|
||||||
|
log_debug "Found tag: $found_tag on page $page"
|
||||||
|
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
page=$((page + 1))
|
||||||
|
|
||||||
|
# Safety limit to prevent infinite loops
|
||||||
|
if [ $page -gt 50 ]; then
|
||||||
|
log_info "Reached pagination safety limit (50 pages)"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
if [ "$ARCH" == "cpu" ]; then
|
if [ "$ARCH" == "cpu" ]; then
|
||||||
# cpu only containers just use the server tag
|
LCPP_TAG=$(fetch_llama_tag "server")
|
||||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
|
||||||
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
|
|
||||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
|
||||||
BASE_TAG=server-${LCPP_TAG}
|
BASE_TAG=server-${LCPP_TAG}
|
||||||
else
|
else
|
||||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
|
||||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
|
||||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
|
||||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
SD_TAG=master-${ARCH}
|
||||||
|
|
||||||
# Abort if LCPP_TAG is empty.
|
# Abort if LCPP_TAG is empty.
|
||||||
if [[ -z "$LCPP_TAG" ]]; then
|
if [[ -z "$LCPP_TAG" ]]; then
|
||||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||||
exit 1
|
exit 1
|
||||||
|
else
|
||||||
|
log_info "LCPP_TAG: $LCPP_TAG"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||||
|
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||||
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
for CONTAINER_TYPE in non-root root; do
|
for CONTAINER_TYPE in non-root root; do
|
||||||
@@ -68,10 +141,22 @@ for CONTAINER_TYPE in non-root root; do
|
|||||||
USER_HOME=/app
|
USER_HOME=/app
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||||
|
|
||||||
|
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||||
|
case "$ARCH" in
|
||||||
|
"musa" | "vulkan")
|
||||||
|
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||||
|
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||||
|
--build-arg BASE=${CONTAINER_TAG} \
|
||||||
|
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||||
|
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||||
|
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||||
|
esac
|
||||||
|
|
||||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
docker push ${CONTAINER_TAG}
|
docker push ${CONTAINER_TAG}
|
||||||
docker push ${CONTAINER_LATEST}
|
docker push ${CONTAINER_LATEST}
|
||||||
|
|||||||
@@ -0,0 +1,305 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Build script for llama-swap-docker with commit hash pinning
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./build-image.sh --cuda # Build CUDA image
|
||||||
|
# ./build-image.sh --vulkan # Build Vulkan image
|
||||||
|
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
|
||||||
|
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
|
||||||
|
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
|
||||||
|
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
|
||||||
|
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
|
||||||
|
#
|
||||||
|
# Features:
|
||||||
|
# - Auto-detects latest commit hashes from git repos
|
||||||
|
# - Builds llama-swap from local source code
|
||||||
|
# - Allows environment variable overrides for reproducible builds
|
||||||
|
# - Cache-friendly: changing commit hash busts cache appropriately
|
||||||
|
# - Supports both CUDA and Vulkan backends (requires explicit flag)
|
||||||
|
#
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
BACKEND=""
|
||||||
|
NO_CACHE=false
|
||||||
|
|
||||||
|
if [[ $# -eq 0 ]]; then
|
||||||
|
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||||
|
echo ""
|
||||||
|
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||||
|
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||||
|
echo " --no-cache Force rebuild without using Docker cache"
|
||||||
|
echo " --help, -h Show this help message"
|
||||||
|
echo ""
|
||||||
|
echo "Environment variables:"
|
||||||
|
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||||
|
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||||
|
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||||
|
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
for arg in "$@"; do
|
||||||
|
case $arg in
|
||||||
|
--cuda)
|
||||||
|
BACKEND="cuda"
|
||||||
|
;;
|
||||||
|
--vulkan)
|
||||||
|
BACKEND="vulkan"
|
||||||
|
;;
|
||||||
|
--no-cache)
|
||||||
|
NO_CACHE=true
|
||||||
|
;;
|
||||||
|
--help|-h)
|
||||||
|
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||||
|
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||||
|
echo " --no-cache Force rebuild without using Docker cache"
|
||||||
|
echo " --help, -h Show this help message"
|
||||||
|
echo ""
|
||||||
|
echo "Environment variables:"
|
||||||
|
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||||
|
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||||
|
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||||
|
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Validate backend selection
|
||||||
|
if [[ -z "$BACKEND" ]]; then
|
||||||
|
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
|
||||||
|
# User provided a custom tag, use it as-is
|
||||||
|
:
|
||||||
|
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
DOCKER_IMAGE_TAG="llama-swap:vulkan"
|
||||||
|
else
|
||||||
|
DOCKER_IMAGE_TAG="llama-swap:cuda"
|
||||||
|
fi
|
||||||
|
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
|
||||||
|
|
||||||
|
# Single unified Dockerfile, backend selected via build arg
|
||||||
|
DOCKERFILE="Dockerfile"
|
||||||
|
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
|
||||||
|
else
|
||||||
|
echo "Building for: CUDA (NVIDIA GPUs)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Git repository URLs
|
||||||
|
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||||
|
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||||
|
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||||
|
|
||||||
|
# Function to get the latest commit hash from a git repo's default branch
|
||||||
|
get_latest_commit() {
|
||||||
|
local repo_url="$1"
|
||||||
|
local branch="${2:-master}"
|
||||||
|
|
||||||
|
# Try to get the latest commit hash for the specified branch
|
||||||
|
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to get the default branch name (master or main)
|
||||||
|
get_default_branch() {
|
||||||
|
local repo_url="$1"
|
||||||
|
|
||||||
|
# Check for master first
|
||||||
|
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
|
||||||
|
echo "master"
|
||||||
|
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
|
||||||
|
echo "main"
|
||||||
|
else
|
||||||
|
echo "master" # fallback
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to get the latest release tag from a GitHub repo
|
||||||
|
get_latest_release_tag() {
|
||||||
|
local owner_repo="$1"
|
||||||
|
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
|
||||||
|
| grep '"tag_name"' | head -1 | cut -d'"' -f4
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "llama-swap-docker Build Script"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Determine commit hashes / release tags - use env vars or auto-detect
|
||||||
|
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
|
||||||
|
# For cuda builds (or whisper on any backend), use git commit hashes.
|
||||||
|
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
|
||||||
|
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
|
||||||
|
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
|
||||||
|
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
|
||||||
|
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
|
||||||
|
else
|
||||||
|
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
|
||||||
|
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
|
||||||
|
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
|
||||||
|
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
|
||||||
|
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
|
||||||
|
else
|
||||||
|
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
|
||||||
|
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
|
||||||
|
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
|
||||||
|
SD_HASH="${SD_COMMIT_HASH}"
|
||||||
|
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
|
||||||
|
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
|
||||||
|
if [[ -z "${SD_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
|
||||||
|
else
|
||||||
|
SD_BRANCH=$(get_default_branch "${SD_REPO}")
|
||||||
|
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
|
||||||
|
if [[ -z "${SD_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Starting Docker build..."
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Build the Docker image with commit hashes as build args
|
||||||
|
# Build context is the repository root (..) so the Dockerfile can access Go source
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||||
|
BUILD_ARGS=(
|
||||||
|
--build-arg "BACKEND=${BACKEND}"
|
||||||
|
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||||
|
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||||
|
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||||
|
-t "${DOCKER_IMAGE_TAG}"
|
||||||
|
-f "${SCRIPT_DIR}/${DOCKERFILE}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ "$NO_CACHE" == true ]]; then
|
||||||
|
BUILD_ARGS+=(--no-cache)
|
||||||
|
echo "Note: Building without cache"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use docker buildx with a custom builder for parallelism control
|
||||||
|
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
|
||||||
|
# We need to use a custom builder with a buildkitd.toml config file
|
||||||
|
BUILDER_NAME="llama-swap-builder"
|
||||||
|
|
||||||
|
# Check if our custom builder exists with the right config, create/update if needed
|
||||||
|
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
|
||||||
|
echo "Creating custom buildx builder with max-parallelism=1..."
|
||||||
|
|
||||||
|
# Create buildkitd.toml config file
|
||||||
|
cat > buildkitd.toml << 'BUILDKIT_EOF'
|
||||||
|
[worker.oci]
|
||||||
|
max-parallelism = 1
|
||||||
|
BUILDKIT_EOF
|
||||||
|
|
||||||
|
# Create the builder with the config
|
||||||
|
docker buildx create --name "$BUILDER_NAME" \
|
||||||
|
--driver docker-container \
|
||||||
|
--buildkitd-config buildkitd.toml \
|
||||||
|
--use
|
||||||
|
else
|
||||||
|
# Switch to our builder
|
||||||
|
docker buildx use "$BUILDER_NAME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Building with sequential stages (one at a time), each using all CPU cores..."
|
||||||
|
echo "Using builder: $BUILDER_NAME"
|
||||||
|
|
||||||
|
# Use docker buildx build with --load to load the image into Docker
|
||||||
|
# The --builder flag ensures we use our custom builder with max-parallelism=1
|
||||||
|
# Build context is the repository root so we can access Go source files
|
||||||
|
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Verifying build artifacts..."
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Verify all expected binaries exist in the image
|
||||||
|
MISSING_BINARIES=()
|
||||||
|
|
||||||
|
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
|
||||||
|
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
|
||||||
|
MISSING_BINARIES+=("${binary}")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||||
|
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
|
||||||
|
for binary in "${MISSING_BINARIES[@]}"; do
|
||||||
|
echo " - ${binary}"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
|
||||||
|
echo " ./build-image.sh --vulkan --no-cache"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Build complete!"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
echo "Image tag: ${DOCKER_IMAGE_TAG}"
|
||||||
|
echo ""
|
||||||
|
echo "Built with:"
|
||||||
|
echo " llama.cpp: ${LLAMA_HASH}"
|
||||||
|
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||||
|
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||||
|
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||||
|
echo ""
|
||||||
|
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
echo "Run with:"
|
||||||
|
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||||
|
echo ""
|
||||||
|
echo "Note: For AMD GPUs, you may also need to mount render devices:"
|
||||||
|
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||||
|
else
|
||||||
|
echo "Run with:"
|
||||||
|
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||||
|
fi
|
||||||
@@ -15,4 +15,19 @@ models:
|
|||||||
cmd: >
|
cmd: >
|
||||||
/app/llama-server
|
/app/llama-server
|
||||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
--port 9999
|
--port 9999
|
||||||
|
|
||||||
|
z-image:
|
||||||
|
checkEndpoint: /
|
||||||
|
cmd: |
|
||||||
|
/app/sd-server
|
||||||
|
--listen-port 9999
|
||||||
|
--diffusion-fa
|
||||||
|
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||||
|
--vae /models/ae.safetensors
|
||||||
|
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||||
|
--offload-to-cpu
|
||||||
|
--cfg-scale 1.0
|
||||||
|
--height 512 --width 512
|
||||||
|
--steps 8
|
||||||
|
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||||
|
ARG SD_TAG=master-vulkan
|
||||||
|
ARG BASE=llama-swap:latest
|
||||||
|
|
||||||
|
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||||
|
FROM ${BASE}
|
||||||
|
|
||||||
|
ARG UID=10001
|
||||||
|
ARG GID=10001
|
||||||
|
|
||||||
|
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||||
@@ -29,6 +29,10 @@ RUN chown --recursive $UID:$GID $HOME /app
|
|||||||
USER $UID:$GID
|
USER $UID:$GID
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Add /app to PATH
|
||||||
|
ENV PATH="/app:${PATH}"
|
||||||
|
|
||||||
RUN \
|
RUN \
|
||||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||||
|
|||||||
@@ -0,0 +1,207 @@
|
|||||||
|
# Unified multi-stage Dockerfile for AI inference tools
|
||||||
|
# Supports CUDA and Vulkan backends via BACKEND build arg
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
|
||||||
|
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
|
||||||
|
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
|
||||||
|
#
|
||||||
|
# Each project has its own install script that handles cloning, building,
|
||||||
|
# and installing binaries. Build stages are independent for cache efficiency.
|
||||||
|
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
|
||||||
|
# ── Builder bases ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
|
||||||
|
|
||||||
|
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
|
||||||
|
ENV CCACHE_DIR=/ccache
|
||||||
|
ENV CCACHE_MAXSIZE=2G
|
||||||
|
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake git python3 python3-pip libssl-dev \
|
||||||
|
curl ca-certificates ccache make wget \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# ──
|
||||||
|
|
||||||
|
FROM ubuntu:24.04 AS builder-base-vulkan
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV CCACHE_DIR=/ccache
|
||||||
|
ENV CCACHE_MAXSIZE=2G
|
||||||
|
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake git python3 python3-pip libssl-dev \
|
||||||
|
curl ca-certificates ccache make wget software-properties-common \
|
||||||
|
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
||||||
|
spirv-headers \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# ── Select builder base by BACKEND ────────────────────────────────────
|
||||||
|
|
||||||
|
FROM builder-base-${BACKEND} AS builder-base
|
||||||
|
|
||||||
|
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
|
||||||
|
|
||||||
|
FROM builder-base AS whisper-build
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
ARG WHISPER_COMMIT_HASH=master
|
||||||
|
COPY install-whisper.sh /build/
|
||||||
|
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||||
|
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
|
||||||
|
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
|
||||||
|
|
||||||
|
# ── Build stable-diffusion.cpp ────────────────────────────────────────
|
||||||
|
|
||||||
|
FROM builder-base AS sd-build
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
ARG SD_COMMIT_HASH=master
|
||||||
|
COPY install-sd.sh /build/
|
||||||
|
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||||
|
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
|
||||||
|
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
|
||||||
|
|
||||||
|
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
|
||||||
|
|
||||||
|
FROM builder-base AS llama-build
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
ARG LLAMA_COMMIT_HASH=master
|
||||||
|
COPY install-llama.sh /build/
|
||||||
|
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||||
|
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
|
||||||
|
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
|
||||||
|
|
||||||
|
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
|
||||||
|
#
|
||||||
|
# Two named stages allow ARG BACKEND to select at build time:
|
||||||
|
# - ik-llama-cuda : real build (from builder-base-cuda)
|
||||||
|
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
|
||||||
|
# BuildKit only evaluates the selected branch, so vulkan builds never
|
||||||
|
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
|
||||||
|
|
||||||
|
FROM builder-base-vulkan AS ik-llama-vulkan
|
||||||
|
RUN mkdir -p /install/bin
|
||||||
|
|
||||||
|
FROM builder-base-cuda AS ik-llama-cuda
|
||||||
|
ARG IK_LLAMA_COMMIT_HASH=main
|
||||||
|
COPY install-ik-llama.sh /build/
|
||||||
|
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
|
||||||
|
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
|
||||||
|
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
|
||||||
|
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
FROM ik-llama-${BACKEND} AS ik-llama-build
|
||||||
|
|
||||||
|
# ── Download llama-swap release binary ────────────────────────────────
|
||||||
|
|
||||||
|
FROM builder-base AS llama-swap-download
|
||||||
|
ARG LS_VERSION=latest
|
||||||
|
COPY install-llama-swap.sh /build/
|
||||||
|
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
|
||||||
|
|
||||||
|
# ── Runtime bases ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
|
||||||
|
ENV PATH="/usr/local/bin:${PATH}"
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
libgomp1 python3 curl ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# CUDA stub drivers for container compatibility
|
||||||
|
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
|
||||||
|
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
|
||||||
|
|
||||||
|
# ──
|
||||||
|
|
||||||
|
FROM ubuntu:24.04 AS runtime-vulkan
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PATH="/usr/local/bin:${PATH}"
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
libgomp1 libvulkan1 mesa-vulkan-drivers \
|
||||||
|
python3 curl ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# ── Select runtime base by BACKEND ────────────────────────────────────
|
||||||
|
|
||||||
|
FROM runtime-${BACKEND} AS runtime
|
||||||
|
|
||||||
|
ARG BACKEND=cuda
|
||||||
|
ARG LLAMA_COMMIT_HASH=unknown
|
||||||
|
ARG WHISPER_COMMIT_HASH=unknown
|
||||||
|
ARG SD_COMMIT_HASH=unknown
|
||||||
|
ARG IK_LLAMA_COMMIT_HASH=unknown
|
||||||
|
ARG RUN_UID=0
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
python3-numpy python3-sentencepiece python3-pip \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create non-root user when RUN_UID != 0
|
||||||
|
RUN if [ "$RUN_UID" != "0" ]; then \
|
||||||
|
groupadd --system --gid $RUN_UID llama-swap && \
|
||||||
|
useradd --system --uid $RUN_UID --gid $RUN_UID \
|
||||||
|
--home /app --shell /sbin/nologin llama-swap; \
|
||||||
|
fi && \
|
||||||
|
mkdir -p /etc/llama-swap/config && \
|
||||||
|
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy whisper.cpp binaries and libraries
|
||||||
|
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
|
||||||
|
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
|
||||||
|
COPY --from=whisper-build /install/lib/ /usr/local/lib/
|
||||||
|
|
||||||
|
# Copy stable-diffusion.cpp binaries and libraries
|
||||||
|
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
|
||||||
|
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
|
||||||
|
COPY --from=sd-build /install/lib/ /usr/local/lib/
|
||||||
|
|
||||||
|
# Copy llama.cpp binaries (statically linked)
|
||||||
|
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
|
||||||
|
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
|
||||||
|
|
||||||
|
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
||||||
|
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
||||||
|
|
||||||
|
# Install uv
|
||||||
|
RUN pip install uv --break-system-packages
|
||||||
|
|
||||||
|
# Copy llama-swap binary
|
||||||
|
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
||||||
|
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
||||||
|
|
||||||
|
RUN ldconfig
|
||||||
|
|
||||||
|
COPY config.example.yaml /etc/llama-swap/config/config.yaml
|
||||||
|
|
||||||
|
# Version tracking
|
||||||
|
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
|
||||||
|
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
|
||||||
|
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
|
||||||
|
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
|
||||||
|
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
|
||||||
|
echo "backend: ${BACKEND}" >> /versions.txt && \
|
||||||
|
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
|
||||||
|
|
||||||
|
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
|
||||||
|
WORKDIR /models
|
||||||
|
USER ${RUN_UID}
|
||||||
|
ENTRYPOINT ["llama-swap"]
|
||||||
|
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# Unified Docker Container
|
||||||
|
|
||||||
|
These scripts create a custom llama-swap container that contains:
|
||||||
|
|
||||||
|
- llama-server for LLMs, rerank and embedding model support
|
||||||
|
- sd-server (stable-diffusion.cpp) for image generation
|
||||||
|
- whisper.cpp for ASR
|
||||||
|
|
||||||
@@ -0,0 +1,303 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Build script for unified container with version pinning
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./build-image.sh --cuda # Build CUDA image
|
||||||
|
# ./build-image.sh --vulkan # Build Vulkan image
|
||||||
|
# ./build-image.sh --cuda --no-cache # Build without cache
|
||||||
|
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
|
||||||
|
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
|
||||||
|
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
|
||||||
|
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
|
||||||
|
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
|
||||||
|
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
|
||||||
|
#
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
BACKEND=""
|
||||||
|
NO_CACHE=false
|
||||||
|
|
||||||
|
for arg in "$@"; do
|
||||||
|
case $arg in
|
||||||
|
--cuda)
|
||||||
|
BACKEND="cuda"
|
||||||
|
;;
|
||||||
|
--vulkan)
|
||||||
|
BACKEND="vulkan"
|
||||||
|
;;
|
||||||
|
--no-cache)
|
||||||
|
NO_CACHE=true
|
||||||
|
;;
|
||||||
|
--help|-h)
|
||||||
|
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||||
|
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||||
|
echo " --no-cache Force rebuild without using Docker cache"
|
||||||
|
echo " --help, -h Show this help message"
|
||||||
|
echo ""
|
||||||
|
echo "Environment variables:"
|
||||||
|
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
|
||||||
|
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
|
||||||
|
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
|
||||||
|
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
|
||||||
|
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
|
||||||
|
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ -z "$BACKEND" ]]; then
|
||||||
|
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||||
|
echo ""
|
||||||
|
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
|
||||||
|
|
||||||
|
# Git repository URLs
|
||||||
|
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||||
|
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||||
|
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||||
|
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
|
||||||
|
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
|
||||||
|
|
||||||
|
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
|
||||||
|
# Requires only: git, network access to the remote.
|
||||||
|
resolve_ref() {
|
||||||
|
local repo_url="$1"
|
||||||
|
local ref="$2"
|
||||||
|
|
||||||
|
# Full 40-char SHA — use as-is
|
||||||
|
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||||
|
echo "${ref}"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Try tag then branch (exact match)
|
||||||
|
local hash
|
||||||
|
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
|
||||||
|
if [[ -n "${hash}" ]]; then
|
||||||
|
echo "${hash}"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Short hash (7+ chars): scan all refs for a SHA with this prefix
|
||||||
|
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
|
||||||
|
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
|
||||||
|
if [[ -n "${hash}" ]]; then
|
||||||
|
echo "${hash}"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
|
||||||
|
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
|
||||||
|
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
|
||||||
|
else
|
||||||
|
echo " Tried: tag, branch, git ls-remote prefix match" >&2
|
||||||
|
fi
|
||||||
|
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Resolve HEAD of a repo without needing to know the default branch name.
|
||||||
|
get_latest_hash() {
|
||||||
|
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "llama-swap Unified Build (${BACKEND})"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Resolve llama.cpp ref
|
||||||
|
if [[ -n "${LLAMA_REF:-}" ]]; then
|
||||||
|
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
|
||||||
|
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
|
||||||
|
else
|
||||||
|
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
|
||||||
|
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Resolve whisper.cpp ref
|
||||||
|
if [[ -n "${WHISPER_REF:-}" ]]; then
|
||||||
|
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
|
||||||
|
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
|
||||||
|
else
|
||||||
|
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
|
||||||
|
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Resolve stable-diffusion.cpp ref
|
||||||
|
if [[ -n "${SD_REF:-}" ]]; then
|
||||||
|
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
|
||||||
|
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
|
||||||
|
else
|
||||||
|
SD_HASH=$(get_latest_hash "${SD_REPO}")
|
||||||
|
if [[ -z "${SD_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Resolve ik_llama.cpp ref (CUDA only)
|
||||||
|
if [[ "$BACKEND" == "cuda" ]]; then
|
||||||
|
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
|
||||||
|
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
|
||||||
|
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
|
||||||
|
else
|
||||||
|
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
|
||||||
|
if [[ -z "${IK_LLAMA_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
IK_LLAMA_HASH="n/a"
|
||||||
|
echo "ik_llama.cpp: skipped (vulkan build)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Resolve llama-swap ref
|
||||||
|
if [[ -n "${LS_VERSION:-}" ]]; then
|
||||||
|
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
|
||||||
|
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
|
||||||
|
else
|
||||||
|
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
|
||||||
|
if [[ -z "${LS_HASH}" ]]; then
|
||||||
|
echo "ERROR: Could not determine latest commit for llama-swap" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "llama-swap: latest HEAD: ${LS_HASH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Starting Docker build..."
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
|
BUILD_ARGS=(
|
||||||
|
--build-arg "BACKEND=${BACKEND}"
|
||||||
|
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||||
|
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||||
|
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||||
|
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
|
||||||
|
--build-arg "LS_VERSION=${LS_HASH}"
|
||||||
|
-t "${DOCKER_IMAGE_TAG}"
|
||||||
|
-f "${SCRIPT_DIR}/Dockerfile"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ "$NO_CACHE" == true ]]; then
|
||||||
|
BUILD_ARGS+=(--no-cache)
|
||||||
|
echo "Note: Building without cache"
|
||||||
|
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
|
||||||
|
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
|
||||||
|
BUILD_ARGS+=(
|
||||||
|
--cache-from "type=registry,ref=${CACHE_REF}"
|
||||||
|
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
|
||||||
|
)
|
||||||
|
echo "Note: Using registry cache (${CACHE_REF})"
|
||||||
|
fi
|
||||||
|
|
||||||
|
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Verifying build artifacts..."
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
|
||||||
|
if [[ "$BACKEND" == "cuda" ]]; then
|
||||||
|
EXPECTED_BINARIES+=(ik-llama-server)
|
||||||
|
fi
|
||||||
|
|
||||||
|
MISSING_BINARIES=()
|
||||||
|
for binary in "${EXPECTED_BINARIES[@]}"; do
|
||||||
|
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
|
||||||
|
MISSING_BINARIES+=("${binary}")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||||
|
echo "ERROR: Build succeeded but the following binaries are missing:"
|
||||||
|
for binary in "${MISSING_BINARIES[@]}"; do
|
||||||
|
echo " - ${binary}"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
echo "Try running with --no-cache flag:"
|
||||||
|
echo " ./build-image.sh --${BACKEND} --no-cache"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||||
|
if [[ "$BACKEND" == "cuda" ]]; then
|
||||||
|
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
|
||||||
|
fi
|
||||||
|
echo "All expected binaries verified: ${VERIFIED_LIST}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Building rootless image..."
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
|
||||||
|
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
|
||||||
|
FROM ${DOCKER_IMAGE_TAG}
|
||||||
|
USER root
|
||||||
|
RUN groupadd --system --gid 10001 llama-swap && \\
|
||||||
|
useradd --system --uid 10001 --gid 10001 \\
|
||||||
|
--home /app --shell /sbin/nologin llama-swap && \\
|
||||||
|
chown -R 10001:10001 /etc/llama-swap /models
|
||||||
|
USER 10001
|
||||||
|
EOF
|
||||||
|
|
||||||
|
echo "Rootless image built: ${ROOTLESS_TAG}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Build complete!"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
echo "Image tags:"
|
||||||
|
echo " ${DOCKER_IMAGE_TAG}"
|
||||||
|
echo " ${ROOTLESS_TAG}"
|
||||||
|
echo ""
|
||||||
|
echo "Built with:"
|
||||||
|
echo " llama.cpp: ${LLAMA_HASH}"
|
||||||
|
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||||
|
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||||
|
if [[ "$BACKEND" == "cuda" ]]; then
|
||||||
|
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
|
||||||
|
fi
|
||||||
|
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||||
|
echo ""
|
||||||
|
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||||
|
echo "Run with:"
|
||||||
|
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||||
|
echo ""
|
||||||
|
echo "Note: For AMD GPUs, you may also need:"
|
||||||
|
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||||
|
else
|
||||||
|
echo "Run with:"
|
||||||
|
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||||
|
fi
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
# placeholder example configuration
|
||||||
|
healthCheckTimeout: 300
|
||||||
|
logRequests: true
|
||||||
|
|
||||||
|
models:
|
||||||
|
"llama":
|
||||||
|
cmd: >
|
||||||
|
llama-server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
"whisper":
|
||||||
|
checkEndpoint: /v1/audio/transcriptions/
|
||||||
|
cmd: >
|
||||||
|
whisper-server
|
||||||
|
--port ${PORT}
|
||||||
|
--m /models/whisper.bin
|
||||||
|
--flash-attn
|
||||||
|
--request-path /v1/audio/transcriptions --inference-path ""
|
||||||
|
|
||||||
|
"image":
|
||||||
|
checkEndpoint: /
|
||||||
|
cmd: |
|
||||||
|
/app/sd-server
|
||||||
|
--listen-port 9999
|
||||||
|
--diffusion-fa
|
||||||
|
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||||
|
--vae /models/ae.safetensors
|
||||||
|
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||||
|
--offload-to-cpu
|
||||||
|
--cfg-scale 1.0
|
||||||
|
--height 512 --width 512
|
||||||
|
--steps 8
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Install ik_llama.cpp - clone, build, and install binaries
|
||||||
|
# Usage: ./install-ik-llama.sh <commit_hash>
|
||||||
|
# Note: CUDA only; always built against builder-base-cuda
|
||||||
|
set -e
|
||||||
|
|
||||||
|
COMMIT_HASH="${1:-main}"
|
||||||
|
|
||||||
|
mkdir -p /install/bin
|
||||||
|
|
||||||
|
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
|
||||||
|
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
|
||||||
|
mkdir -p /src/ik_llama.cpp
|
||||||
|
cd /src/ik_llama.cpp
|
||||||
|
if [ ! -d .git ]; then
|
||||||
|
git init
|
||||||
|
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
|
||||||
|
fi
|
||||||
|
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
|
CMAKE_FLAGS=(
|
||||||
|
-DGGML_NATIVE=OFF
|
||||||
|
-DBUILD_SHARED_LIBS=OFF
|
||||||
|
-DCMAKE_BUILD_TYPE=Release
|
||||||
|
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||||
|
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
-DGGML_CUDA=ON
|
||||||
|
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||||
|
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||||
|
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
|
||||||
|
)
|
||||||
|
|
||||||
|
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||||
|
|
||||||
|
echo "=== Building ik_llama.cpp ==="
|
||||||
|
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||||
|
cmake --build build --config Release -j"$(nproc)" --target llama-server
|
||||||
|
|
||||||
|
if [ ! -f "build/bin/llama-server" ]; then
|
||||||
|
echo "FATAL: llama-server not found in build/bin/" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
|
||||||
|
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
|
||||||
|
echo "=== ik_llama.cpp build complete ==="
|
||||||
|
ls -la /install/bin/
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Install llama-swap - download latest release binary from GitHub
|
||||||
|
# Usage: ./install-llama-swap.sh [version]
|
||||||
|
# version: release version number (e.g., "170") or "latest" (default)
|
||||||
|
set -e
|
||||||
|
|
||||||
|
VERSION="${1:-latest}"
|
||||||
|
REPO="mostlygeek/llama-swap"
|
||||||
|
|
||||||
|
mkdir -p /install/bin
|
||||||
|
|
||||||
|
# If a full commit hash is given, find the release tag that points to it
|
||||||
|
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
|
||||||
|
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
|
||||||
|
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
|
||||||
|
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
|
||||||
|
if [ -n "${TAG}" ]; then
|
||||||
|
echo "Resolved to tag: ${TAG}"
|
||||||
|
VERSION="${TAG#v}"
|
||||||
|
else
|
||||||
|
echo "No release tag found for commit ${VERSION:0:7}, using latest"
|
||||||
|
VERSION="latest"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Strip leading 'v' prefix so both "198" and "v198" work
|
||||||
|
VERSION="${VERSION#v}"
|
||||||
|
|
||||||
|
# Resolve "latest" to actual version number
|
||||||
|
if [ "$VERSION" = "latest" ]; then
|
||||||
|
echo "=== Resolving latest llama-swap release ==="
|
||||||
|
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
|
||||||
|
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
|
||||||
|
if [ -z "$VERSION" ]; then
|
||||||
|
echo "FATAL: Could not determine latest release version" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Latest version: ${VERSION}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Download and extract
|
||||||
|
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
|
||||||
|
echo "=== Downloading llama-swap v${VERSION} ==="
|
||||||
|
echo "URL: $URL"
|
||||||
|
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
||||||
|
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
|
||||||
|
rm /tmp/llama-swap.tar.gz
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
if [ ! -x "/install/bin/llama-swap" ]; then
|
||||||
|
echo "FATAL: llama-swap binary not found or not executable" >&2
|
||||||
|
ls -la /install/bin/ >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "$VERSION" > /install/llama-swap-version
|
||||||
|
|
||||||
|
echo "=== llama-swap v${VERSION} installed ==="
|
||||||
|
ls -la /install/bin/llama-swap
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Install llama.cpp - clone, build, and install binaries
|
||||||
|
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
|
||||||
|
set -e
|
||||||
|
|
||||||
|
COMMIT_HASH="${1:-master}"
|
||||||
|
BACKEND="${BACKEND:-cuda}"
|
||||||
|
|
||||||
|
mkdir -p /install/bin
|
||||||
|
|
||||||
|
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
|
||||||
|
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
|
||||||
|
mkdir -p /src/llama.cpp
|
||||||
|
cd /src/llama.cpp
|
||||||
|
if [ ! -d .git ]; then
|
||||||
|
git init
|
||||||
|
git remote add origin https://github.com/ggml-org/llama.cpp.git
|
||||||
|
fi
|
||||||
|
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
|
# Common cmake flags
|
||||||
|
CMAKE_FLAGS=(
|
||||||
|
-DGGML_NATIVE=OFF
|
||||||
|
-DBUILD_SHARED_LIBS=OFF
|
||||||
|
-DCMAKE_BUILD_TYPE=Release
|
||||||
|
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||||
|
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
-DLLAMA_BUILD_TESTS=OFF
|
||||||
|
)
|
||||||
|
|
||||||
|
if [ "$BACKEND" = "cuda" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=ON
|
||||||
|
-DGGML_VULKAN=OFF
|
||||||
|
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||||
|
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||||
|
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||||
|
)
|
||||||
|
elif [ "$BACKEND" = "vulkan" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=OFF
|
||||||
|
-DGGML_VULKAN=ON
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
TARGETS=(llama-cli llama-server)
|
||||||
|
|
||||||
|
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||||
|
|
||||||
|
echo "=== Building llama.cpp for ${BACKEND} ==="
|
||||||
|
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||||
|
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||||
|
|
||||||
|
for bin in "${TARGETS[@]}"; do
|
||||||
|
if [ ! -f "build/bin/$bin" ]; then
|
||||||
|
echo "FATAL: $bin not found in build/bin/" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cp "build/bin/$bin" "/install/bin/"
|
||||||
|
done
|
||||||
|
echo "=== llama.cpp build complete ==="
|
||||||
|
ls -la /install/bin/
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Install stable-diffusion.cpp - clone, build, and install binaries and library
|
||||||
|
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
|
||||||
|
set -e
|
||||||
|
|
||||||
|
COMMIT_HASH="${1:-master}"
|
||||||
|
BACKEND="${BACKEND:-cuda}"
|
||||||
|
|
||||||
|
mkdir -p /install/bin /install/lib
|
||||||
|
|
||||||
|
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
|
||||||
|
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
|
||||||
|
mkdir -p /src/stable-diffusion.cpp
|
||||||
|
cd /src/stable-diffusion.cpp
|
||||||
|
if [ ! -d .git ]; then
|
||||||
|
git init
|
||||||
|
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
|
||||||
|
fi
|
||||||
|
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
git submodule update --init --recursive --depth=1
|
||||||
|
|
||||||
|
# Common cmake flags
|
||||||
|
CMAKE_FLAGS=(
|
||||||
|
-DGGML_NATIVE=OFF
|
||||||
|
-DCMAKE_BUILD_TYPE=Release
|
||||||
|
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||||
|
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
-DSD_BUILD_EXAMPLES=ON
|
||||||
|
)
|
||||||
|
|
||||||
|
if [ "$BACKEND" = "cuda" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=ON
|
||||||
|
-DGGML_VULKAN=OFF
|
||||||
|
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||||
|
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||||
|
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||||
|
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||||
|
-DSD_CUDA=ON
|
||||||
|
)
|
||||||
|
elif [ "$BACKEND" = "vulkan" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=OFF
|
||||||
|
-DGGML_VULKAN=ON
|
||||||
|
-DSD_VULKAN=ON
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
TARGETS=(stable-diffusion sd-cli sd-server)
|
||||||
|
|
||||||
|
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||||
|
|
||||||
|
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
|
||||||
|
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||||
|
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||||
|
|
||||||
|
for bin in sd-cli sd-server; do
|
||||||
|
if [ ! -f "build/bin/$bin" ]; then
|
||||||
|
echo "FATAL: $bin not found in build/bin/" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cp "build/bin/$bin" "/install/bin/"
|
||||||
|
done
|
||||||
|
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||||
|
|
||||||
|
echo "=== stable-diffusion.cpp build complete ==="
|
||||||
|
ls -la /install/bin/ /install/lib/
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Install whisper.cpp - clone, build, and install binaries
|
||||||
|
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
|
||||||
|
set -e
|
||||||
|
|
||||||
|
COMMIT_HASH="${1:-master}"
|
||||||
|
BACKEND="${BACKEND:-cuda}"
|
||||||
|
|
||||||
|
mkdir -p /install/bin /install/lib
|
||||||
|
|
||||||
|
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
|
||||||
|
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
|
||||||
|
mkdir -p /src/whisper.cpp
|
||||||
|
cd /src/whisper.cpp
|
||||||
|
if [ ! -d .git ]; then
|
||||||
|
git init
|
||||||
|
git remote add origin https://github.com/ggml-org/whisper.cpp.git
|
||||||
|
fi
|
||||||
|
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
|
# Common cmake flags
|
||||||
|
CMAKE_FLAGS=(
|
||||||
|
-DGGML_NATIVE=OFF
|
||||||
|
-DCMAKE_BUILD_TYPE=Release
|
||||||
|
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||||
|
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
)
|
||||||
|
|
||||||
|
if [ "$BACKEND" = "cuda" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=ON
|
||||||
|
-DGGML_VULKAN=OFF
|
||||||
|
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||||
|
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||||
|
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||||
|
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||||
|
)
|
||||||
|
elif [ "$BACKEND" = "vulkan" ]; then
|
||||||
|
CMAKE_FLAGS+=(
|
||||||
|
-DGGML_CUDA=OFF
|
||||||
|
-DGGML_VULKAN=ON
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
TARGETS=(whisper-cli whisper-server)
|
||||||
|
|
||||||
|
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||||
|
|
||||||
|
echo "=== Building whisper.cpp for ${BACKEND} ==="
|
||||||
|
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||||
|
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||||
|
|
||||||
|
for bin in "${TARGETS[@]}"; do
|
||||||
|
if [ ! -f "build/bin/$bin" ]; then
|
||||||
|
echo "FATAL: $bin not found in build/bin/" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cp "build/bin/$bin" "/install/bin/"
|
||||||
|
done
|
||||||
|
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||||
|
|
||||||
|
echo "=== whisper.cpp build complete ==="
|
||||||
|
ls -la /install/bin/
|
||||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 351 KiB |
|
After Width: | Height: | Size: 198 KiB |
@@ -22,7 +22,7 @@ models:
|
|||||||
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
||||||
```
|
```
|
||||||
|
|
||||||
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
|
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model which is useful if you want to run multiple models at the same time with the `matrix` feature.
|
||||||
|
|
||||||
## Advanced control with `cmd`
|
## Advanced control with `cmd`
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ llama-swap supports many more features to customize how you want to manage your
|
|||||||
| --------- | ---------------------------------------------- |
|
| --------- | ---------------------------------------------- |
|
||||||
| `ttl` | automatic unloading of models after a timeout |
|
| `ttl` | automatic unloading of models after a timeout |
|
||||||
| `macros` | reusable snippets to use in configurations |
|
| `macros` | reusable snippets to use in configurations |
|
||||||
| `groups` | run multiple models at a time |
|
| `matrix` | run multiple models at a time |
|
||||||
| `hooks` | event driven functionality |
|
| `hooks` | event driven functionality |
|
||||||
| `env` | define environment variables per model |
|
| `env` | define environment variables per model |
|
||||||
| `aliases` | serve a model with different names |
|
| `aliases` | serve a model with different names |
|
||||||
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
|
|||||||
## Full Configuration Example
|
## Full Configuration Example
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
# add this modeline for validation in vscode
|
||||||
|
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||||
|
#
|
||||||
# llama-swap YAML configuration example
|
# llama-swap YAML configuration example
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
#
|
#
|
||||||
@@ -114,18 +117,60 @@ healthCheckTimeout: 500
|
|||||||
# - Valid log levels: debug, info, warn, error
|
# - Valid log levels: debug, info, warn, error
|
||||||
logLevel: info
|
logLevel: info
|
||||||
|
|
||||||
|
# logTimeFormat: enables and sets the logging timestamp format
|
||||||
|
# - optional, default (disabled): ""
|
||||||
|
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||||
|
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||||
|
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||||
|
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||||
|
logTimeFormat: ""
|
||||||
|
|
||||||
|
# logToStdout: controls what is logged to stdout
|
||||||
|
# - optional, default: "proxy"
|
||||||
|
# - valid values:
|
||||||
|
# - "proxy": logs generated by llama-swap when swapping models,
|
||||||
|
# handling requests, etc.
|
||||||
|
# - "upstream": a copy of an upstream processes stdout logs
|
||||||
|
# - "both": both the proxy and upstream logs interleaved together
|
||||||
|
# - "none": no logs are ever written to stdout
|
||||||
|
logToStdout: "proxy"
|
||||||
|
|
||||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
# - optional, default: 1000
|
# - optional, default: 1000
|
||||||
# - controls how many metrics are stored in memory before older ones are discarded
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
# - useful for limiting memory usage when processing large volumes of metrics
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
metricsMaxInMemory: 1000
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||||
|
# - optional, default: 10
|
||||||
|
# - set to 0 to disable
|
||||||
|
captureBuffer: 15
|
||||||
|
|
||||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
# - optional, default: 5800
|
# - optional, default: 5800
|
||||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
# - it is automatically incremented for every model that uses it
|
# - it is automatically incremented for every model that uses it
|
||||||
startPort: 10001
|
startPort: 10001
|
||||||
|
|
||||||
|
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||||
|
# field
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, a stream of loading messages will be sent to the client in the
|
||||||
|
# reasoning field so chat UIs can show that loading is in progress.
|
||||||
|
# - see #366 for more details
|
||||||
|
sendLoadingState: true
|
||||||
|
|
||||||
|
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, model aliases will be output to the API model listing duplicating
|
||||||
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# globalTTL: the default TTL in seconds before unloading a model
|
||||||
|
# - optional, default: 0 (never automatically unload)
|
||||||
|
# - must be >= 0
|
||||||
|
globalTTL: 0
|
||||||
|
|
||||||
# macros: a dictionary of string substitutions
|
# macros: a dictionary of string substitutions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros are reusable snippets
|
# - macros are reusable snippets
|
||||||
@@ -136,6 +181,9 @@ startPort: 10001
|
|||||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
# - macro values can be numbers, bools, or strings
|
# - macro values can be numbers, bools, or strings
|
||||||
# - macros can contain other macros, but they must be defined before they are used
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
|
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||||
|
# - env macros are substituted first, before regular macros
|
||||||
|
# - if the env var is not set, config loading will fail with an error
|
||||||
macros:
|
macros:
|
||||||
# Example of a multi-line macro
|
# Example of a multi-line macro
|
||||||
"latest-llama": >
|
"latest-llama": >
|
||||||
@@ -148,6 +196,24 @@ macros:
|
|||||||
# but they must be previously declared.
|
# but they must be previously declared.
|
||||||
"default_args": "--ctx-size ${default_ctx}"
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
|
# Example of environment variable macros
|
||||||
|
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||||
|
# - useful for paths, secrets, or machine-specific configuration
|
||||||
|
"models_dir": "${env.HOME}/models"
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
|
||||||
|
# use environment variable macros to keep secrets out of the config
|
||||||
|
- "${env.API_KEY_1}"
|
||||||
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
@@ -156,7 +222,7 @@ macros:
|
|||||||
# - below are examples of the all the settings a model can have
|
# - below are examples of the all the settings a model can have
|
||||||
models:
|
models:
|
||||||
# keys are the model names used in API requests
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"gpt-oss-120b":
|
||||||
# macros: a dictionary of string substitutions specific to this model
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - macros defined here override macros defined in the global macros section
|
# - macros defined here override macros defined in the global macros section
|
||||||
@@ -173,7 +239,7 @@ models:
|
|||||||
cmd: |
|
cmd: |
|
||||||
# ${latest-llama} is a macro that is defined above
|
# ${latest-llama} is a macro that is defined above
|
||||||
${latest-llama}
|
${latest-llama}
|
||||||
--model path/to/llama-8B-Q4_K_M.gguf
|
--model path/to/gpt-oss-120B.gguf
|
||||||
--ctx-size ${default_ctx}
|
--ctx-size ${default_ctx}
|
||||||
--temperature ${temp}
|
--temperature ${temp}
|
||||||
|
|
||||||
@@ -181,13 +247,13 @@ models:
|
|||||||
# - optional, default: empty string
|
# - optional, default: empty string
|
||||||
# - if set, it will be used in the v1/models API response
|
# - if set, it will be used in the v1/models API response
|
||||||
# - if not set, it will be omitted in the JSON model record
|
# - if not set, it will be omitted in the JSON model record
|
||||||
name: "llama 3.1 8B"
|
name: "gpt-oss 120B"
|
||||||
|
|
||||||
# description: a description for the model
|
# description: a description for the model
|
||||||
# - optional, default: empty string
|
# - optional, default: empty string
|
||||||
# - if set, it will be used in the v1/models API response
|
# - if set, it will be used in the v1/models API response
|
||||||
# - if not set, it will be omitted in the JSON model record
|
# - if not set, it will be omitted in the JSON model record
|
||||||
description: "A small but capable model used for quick testing"
|
description: "A thinking model from OpenAI"
|
||||||
|
|
||||||
# env: define an array of environment variables to inject into cmd's environment
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
# - optional, default: empty array
|
# - optional, default: empty array
|
||||||
@@ -202,14 +268,6 @@ models:
|
|||||||
# - if you use a custom port in cmd this *must* be set
|
# - if you use a custom port in cmd this *must* be set
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases: alternative model names that this model configuration is used for
|
|
||||||
# - optional, default: empty array
|
|
||||||
# - aliases must be unique globally
|
|
||||||
# - useful for impersonating a specific model
|
|
||||||
aliases:
|
|
||||||
- "gpt-4o-mini"
|
|
||||||
- "gpt-3.5-turbo"
|
|
||||||
|
|
||||||
# checkEndpoint: URL path to check if the server is ready
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
# - optional, default: /health
|
# - optional, default: /health
|
||||||
# - endpoint is expected to return an HTTP 200 response
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
@@ -218,8 +276,10 @@ models:
|
|||||||
checkEndpoint: /custom-endpoint
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# ttl: automatically unload the model after ttl seconds
|
# ttl: automatically unload the model after ttl seconds
|
||||||
# - optional, default: 0
|
# - optional, default: -1 (use global default)
|
||||||
# - ttl values must be a value greater than 0
|
# - ttl values must be a value greater than or equal to 0
|
||||||
|
# - a ttl of -1 will use the global TTL value as the default
|
||||||
|
# - a ttl of 0 will mean never unload
|
||||||
# - a value of 0 disables automatic unloading of the model
|
# - a value of 0 disables automatic unloading of the model
|
||||||
ttl: 60
|
ttl: 60
|
||||||
|
|
||||||
@@ -227,11 +287,11 @@ models:
|
|||||||
# - optional, default: ""
|
# - optional, default: ""
|
||||||
# - useful for when the upstream server expects a specific model name that
|
# - useful for when the upstream server expects a specific model name that
|
||||||
# is different from the model's ID
|
# is different from the model's ID
|
||||||
useModelName: "qwen:qwq"
|
useModelName: "openai/gpt-oss-120B"
|
||||||
|
|
||||||
# filters: a dictionary of filter settings
|
# filters: a dictionary of filter settings
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - only stripParams is currently supported
|
# - same capabilities as peer filters (stripParams, setParams)
|
||||||
filters:
|
filters:
|
||||||
# stripParams: a comma separated list of parameters to remove from the request
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
# - optional, default: ""
|
# - optional, default: ""
|
||||||
@@ -241,6 +301,43 @@ models:
|
|||||||
# - recommended to stick to sampling parameters
|
# - recommended to stick to sampling parameters
|
||||||
stripParams: "temperature, top_p, top_k"
|
stripParams: "temperature, top_p, top_k"
|
||||||
|
|
||||||
|
# setParams: a dictionary of parameters to set/override in requests
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - useful for enforcing specific parameter values
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
# - always runs for the model
|
||||||
|
setParams:
|
||||||
|
# Example: enforce specific sampling parameters
|
||||||
|
temperature: 0.7
|
||||||
|
top_p: 0.9
|
||||||
|
|
||||||
|
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - combine with aliases to create variant behaviour without reloading the model
|
||||||
|
# - parameters are set in the request body JSON
|
||||||
|
# - run after setParams so it will override any settings
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
# - model aliases will be automatically created for each key
|
||||||
|
setParamsByID:
|
||||||
|
"${MODEL_ID}":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: medium
|
||||||
|
"${MODEL_ID}:high":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: high
|
||||||
|
"${MODEL_ID}:low":
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: low
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
|
aliases:
|
||||||
|
- "gpt-4o-mini"
|
||||||
|
|
||||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
# - while metadata can contains complex types it is recommended to keep it simple
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
@@ -274,6 +371,26 @@ models:
|
|||||||
# - recommended to be omitted and the default used
|
# - recommended to be omitted and the default used
|
||||||
concurrencyLimit: 0
|
concurrencyLimit: 0
|
||||||
|
|
||||||
|
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||||
|
# - optional, default: undefined (use global setting)
|
||||||
|
sendLoadingState: false
|
||||||
|
|
||||||
|
# timeouts: configure proxy connection timeouts for this model
|
||||||
|
# - optional, defaults shown below
|
||||||
|
# - useful for models running on slower hardware that need longer timeouts
|
||||||
|
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
|
||||||
|
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
|
||||||
|
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
|
||||||
|
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
|
||||||
|
# - idleConn: idle connection timeout in seconds, default: 90 seconds
|
||||||
|
# - set any value to 0 to disable that timeout (not recommended)
|
||||||
|
timeouts:
|
||||||
|
connect: 30
|
||||||
|
keepalive: 0
|
||||||
|
responseHeader: 60
|
||||||
|
tlsHandshake: 10
|
||||||
|
idleConn: 90
|
||||||
|
|
||||||
# Unlisted model example:
|
# Unlisted model example:
|
||||||
"qwen-unlisted":
|
"qwen-unlisted":
|
||||||
# unlisted: boolean, true or false
|
# unlisted: boolean, true or false
|
||||||
@@ -305,68 +422,83 @@ models:
|
|||||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
cmdStop: docker stop ${MODEL_ID}
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# groups: a dictionary of group settings
|
# =============================================================================
|
||||||
# - optional, default: empty dictionary
|
# matrix: run concurrent models with a solver-based swap DSL
|
||||||
# - provides advanced controls over model swapping behaviour
|
# =============================================================================
|
||||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
|
||||||
# - model IDs must be defined in the Models section
|
|
||||||
# - a model can only be a member of one group
|
|
||||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
|
||||||
# - see issue #109 for details
|
|
||||||
#
|
#
|
||||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
# Note:
|
||||||
groups:
|
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||||
# to run a time across the whole llama-swap instance
|
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||||
"group1":
|
#
|
||||||
# swap: controls the model swapping behaviour in within the group
|
# The matrix declares valid combinations of models that can run concurrently.
|
||||||
# - optional, default: true
|
# When a model is requested, the solver finds the cheapest way to make it
|
||||||
# - true : only one model is allowed to run at a time
|
# available by evicting as few (and least costly) running models as possible.
|
||||||
# - false: all models can run together, no swapping
|
#
|
||||||
swap: true
|
# Solver behavior:
|
||||||
|
# 1. Request arrives for model X
|
||||||
|
# 2. If X is already running, forward immediately. Done.
|
||||||
|
# 3. Find all sets containing X
|
||||||
|
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||||
|
# every running model NOT in that set
|
||||||
|
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||||
|
# 6. Evict what needs to stop. Start X. Forward request.
|
||||||
|
#
|
||||||
|
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||||
|
# Only the requested model is started — others are not preloaded.
|
||||||
|
#
|
||||||
|
# A model not appearing in any set can only run alone.
|
||||||
|
#
|
||||||
|
matrix:
|
||||||
|
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||||
|
# - required for sets and evict_costs settings
|
||||||
|
# - each entry is a short name to a real model ID. Do not use an alias
|
||||||
|
# - used to keep set DSL logic short and easier to read
|
||||||
|
# - sets and evict_costs only use identifiers defined in vars
|
||||||
|
vars:
|
||||||
|
g: gemma-model
|
||||||
|
q: qwen-model
|
||||||
|
m: mistral-model
|
||||||
|
v: voxtral-model
|
||||||
|
e: reranker-model
|
||||||
|
L: llama-70B
|
||||||
|
sd: stable-diffusion
|
||||||
|
|
||||||
# exclusive: controls how the group affects other groups
|
# evict_costs: relative cost of losing a running model (default: 1)
|
||||||
# - optional, default: true
|
evict_costs:
|
||||||
# - true: causes all other groups to unload when this group runs a model
|
v: 50 # vllm backend, slow cold start
|
||||||
# - false: does not affect other groups
|
L: 30 # 70B weights, slow to load
|
||||||
exclusive: true
|
|
||||||
|
|
||||||
# members references the models defined above
|
# sets: named sets of concurrent model combinations
|
||||||
# required
|
# Values are DSL strings with operators:
|
||||||
members:
|
# & AND (models run together)
|
||||||
- "llama"
|
# | OR (alternatives)
|
||||||
- "qwen-unlisted"
|
# () grouping
|
||||||
|
# +ref inline another set's expression
|
||||||
|
#
|
||||||
|
# Expansion examples:
|
||||||
|
# "L" → [L]
|
||||||
|
# "a & b" → [a, b]
|
||||||
|
# "a | b" → [a], [b]
|
||||||
|
# "(a | b) & c" → [a, c], [b, c]
|
||||||
|
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||||
|
# "+llms & v" → expands llms inline, then applies & v
|
||||||
|
sets:
|
||||||
|
# LLM + TTS: switching between g/q/m won't evict v
|
||||||
|
# expands to: [g,v], [q,v], [m,v]
|
||||||
|
standard: "(g | q | m) & v"
|
||||||
|
|
||||||
# Example:
|
# LLM + TTS + reranker
|
||||||
# - in group2 all models can run at the same time
|
# expands to: [g,v,e], [q,v,e]
|
||||||
# - when a different group is loaded it causes all running models in this group to unload
|
with_rerank: "(g | q) & v & e"
|
||||||
"group2":
|
|
||||||
swap: false
|
|
||||||
|
|
||||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
# LLM + image generation, no TTS
|
||||||
# - the models in group2 will be loaded but will not unload any other groups
|
# expands to: [g,sd], [q,sd]
|
||||||
exclusive: false
|
creative: "(g | q) & sd"
|
||||||
members:
|
|
||||||
- "docker-llama"
|
|
||||||
- "modelA"
|
|
||||||
- "modelB"
|
|
||||||
|
|
||||||
# Example:
|
# 70B model uses all GPUs, can only run alone
|
||||||
# - a persistent group, prevents other groups from unloading it
|
# expands to: [L]
|
||||||
"forever":
|
full: "L"
|
||||||
# persistent: prevents over groups from unloading the models in this group
|
|
||||||
# - optional, default: false
|
|
||||||
# - does not affect individual model behaviour
|
|
||||||
persistent: true
|
|
||||||
|
|
||||||
# set swap/exclusive to false to prevent swapping inside the group
|
|
||||||
# and the unloading of other groups
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "forever-modelA"
|
|
||||||
- "forever-modelB"
|
|
||||||
- "forever-modelc"
|
|
||||||
|
|
||||||
# hooks: a dictionary of event triggers and actions
|
# hooks: a dictionary of event triggers and actions
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
@@ -383,4 +515,68 @@ hooks:
|
|||||||
# otherwise models will be loaded and swapped out
|
# otherwise models will be loaded and swapped out
|
||||||
preload:
|
preload:
|
||||||
- "llama"
|
- "llama"
|
||||||
|
|
||||||
|
# peers: a dictionary of remote peers and models they provide
|
||||||
|
# - optional, default empty dictionary
|
||||||
|
# - peers can be another llama-swap
|
||||||
|
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||||
|
peers:
|
||||||
|
# keys is the peer'd ID
|
||||||
|
llama-swap-peer:
|
||||||
|
# proxy: a valid base URL to proxy requests to
|
||||||
|
# - required
|
||||||
|
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
# models: a list of models served by the peer
|
||||||
|
# - required
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
- embeddings/model_c
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
# apiKey: a string key to be injected into the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - if blank, no key will be added to the request
|
||||||
|
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||||
|
# - can be a string or a macro
|
||||||
|
apiKey: ${env.OPENROUTER_API_KEY}
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
- qwen/qwen3-235b-a22b-2507
|
||||||
|
- deepseek/deepseek-v3.2
|
||||||
|
- z-ai/glm-4.7
|
||||||
|
- moonshotai/kimi-k2-0905
|
||||||
|
- minimax/minimax-m2.1
|
||||||
|
# timeouts: configure proxy connection timeouts for this peer
|
||||||
|
# - optional, defaults shown below
|
||||||
|
# - useful when the peer runs on slower hardware
|
||||||
|
# - set any value to 0 to disable that timeout (not recommended)
|
||||||
|
timeouts:
|
||||||
|
connect: 30
|
||||||
|
keepalive: 30
|
||||||
|
responseHeader: 60
|
||||||
|
tlsHandshake: 10
|
||||||
|
idleConn: 90
|
||||||
|
|
||||||
|
# filters: a dictionary of filter settings for peer requests
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - same capabilities as model filters (stripParams, setParams)
|
||||||
|
filters:
|
||||||
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for removing parameters that the peer doesn't support
|
||||||
|
# - the `model` parameter can never be removed
|
||||||
|
stripParams: "temperature, top_p"
|
||||||
|
|
||||||
|
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - useful for injecting provider-specific settings like data retention policies
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
setParams:
|
||||||
|
# Example: enforce zero-data-retention for OpenRouter
|
||||||
|
provider:
|
||||||
|
data_collection: "deny"
|
||||||
|
zdr: true
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
## Container Security
|
||||||
|
|
||||||
|
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||||
|
|
||||||
|
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||||
|
|
||||||
|
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||||
|
|
||||||
|
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
module github.com/mostlygeek/llama-swap
|
module github.com/mostlygeek/llama-swap
|
||||||
|
|
||||||
go 1.25.4
|
go 1.26.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/billziss-gh/golib v0.2.0
|
github.com/billziss-gh/golib v0.2.0
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
github.com/klauspost/compress v1.18.5
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
@@ -34,6 +32,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||||
|
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
|
|||||||
@@ -9,14 +9,15 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/configwatcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -79,6 +80,17 @@ func main() {
|
|||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Reload signals (SIGHUP on POSIX, none on Windows — Windows does not
|
||||||
|
// deliver SIGHUP). Always wired up so `kill -HUP` works regardless of
|
||||||
|
// --watch-config.
|
||||||
|
reloadChan := make(chan os.Signal, 1)
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
signal.Notify(reloadChan, syscall.SIGHUP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context that bounds the lifetime of background watcher goroutines.
|
||||||
|
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Create server with initial handler
|
// Create server with initial handler
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: *listenStr,
|
Addr: *listenStr,
|
||||||
@@ -121,52 +133,45 @@ func main() {
|
|||||||
// load the initial proxy manager
|
// load the initial proxy manager
|
||||||
reloadProxyManager()
|
reloadProxyManager()
|
||||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
if *watchConfig {
|
|
||||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
|
||||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
|
||||||
debouncedReload()
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
|
|
||||||
fmt.Println("Watching Configuration for changes")
|
// Listen for ConfigFileChangedEvent unconditionally so SIGHUP and the
|
||||||
|
// poll-based watcher both feed the same debounced reload pipeline. The
|
||||||
|
// UI also listens for the matching ReloadingStateEnd emitted from
|
||||||
|
// reloadProxyManager.
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
// SIGHUP (or platform-equivalent) → reload. Back-to-back signals collapse
|
||||||
|
// to one reload via the debounce window, which is the desired behavior.
|
||||||
|
go func() {
|
||||||
|
for range reloadChan {
|
||||||
|
fmt.Println("Received reload signal, reloading configuration")
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if *watchConfig {
|
||||||
go func() {
|
go func() {
|
||||||
absConfigPath, err := filepath.Abs(*configPath)
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
watcher, err := fsnotify.NewWatcher()
|
fmt.Println("Watching configuration for changes (poll-based, 2s interval)")
|
||||||
if err != nil {
|
(&configwatcher.Watcher{
|
||||||
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
Path: absConfigPath,
|
||||||
return
|
Interval: configwatcher.DefaultInterval,
|
||||||
}
|
OnChange: func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
configDir := filepath.Dir(absConfigPath)
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
err = watcher.Add(configDir)
|
})
|
||||||
if err != nil {
|
},
|
||||||
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
}).Run(watcherCtx)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer watcher.Close()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case changeEvent := <-watcher.Events:
|
|
||||||
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
|
||||||
event.Emit(proxy.ConfigFileChangedEvent{
|
|
||||||
ReloadingState: proxy.ReloadingStateStart,
|
|
||||||
})
|
|
||||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
|
||||||
// the change for k8s configmap
|
|
||||||
event.Emit(proxy.ConfigFileChangedEvent{
|
|
||||||
ReloadingState: proxy.ReloadingStateStart,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
case err := <-watcher.Errors:
|
|
||||||
log.Printf("File watcher error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,6 +179,7 @@ func main() {
|
|||||||
go func() {
|
go func() {
|
||||||
sig := <-sigChan
|
sig := <-sigChan
|
||||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
watcherCancel()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const DEFAULT_GROUP_ID = "(default)"
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
const (
|
||||||
|
LogToStdoutProxy = "proxy"
|
||||||
|
LogToStdoutUpstream = "upstream"
|
||||||
|
LogToStdoutBoth = "both"
|
||||||
|
LogToStdoutNone = "none"
|
||||||
|
)
|
||||||
|
|
||||||
type MacroEntry struct {
|
type MacroEntry struct {
|
||||||
Name string
|
Name string
|
||||||
@@ -81,6 +87,7 @@ type GroupConfig struct {
|
|||||||
var (
|
var (
|
||||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||||
)
|
)
|
||||||
|
|
||||||
// set default values for GroupConfig
|
// set default values for GroupConfig
|
||||||
@@ -114,11 +121,20 @@ type Config struct {
|
|||||||
LogRequests bool `yaml:"logRequests"`
|
LogRequests bool `yaml:"logRequests"`
|
||||||
LogLevel string `yaml:"logLevel"`
|
LogLevel string `yaml:"logLevel"`
|
||||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||||
|
LogToStdout string `yaml:"logToStdout"`
|
||||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||||
|
CaptureBuffer int `yaml:"captureBuffer"`
|
||||||
|
GlobalTTL int `yaml:"globalTTL"`
|
||||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// swap matrix: solver-based alternative to groups
|
||||||
|
Matrix *MatrixConfig `yaml:"matrix"`
|
||||||
|
|
||||||
|
// populated during validation when matrix is configured
|
||||||
|
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||||
|
|
||||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
Macros MacroList `yaml:"macros"`
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
@@ -136,6 +152,12 @@ type Config struct {
|
|||||||
|
|
||||||
// present aliases to /v1/models OpenAI API listing
|
// present aliases to /v1/models OpenAI API listing
|
||||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||||
|
|
||||||
|
// support API keys, see issue #433, #50, #251
|
||||||
|
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||||
|
|
||||||
|
// support remote peers, see issue #433, #296
|
||||||
|
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -170,22 +192,31 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
yamlStr := string(data)
|
||||||
|
|
||||||
// default configuration values
|
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||||
|
// This is safe because env values are simple strings without YAML formatting
|
||||||
|
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal into full Config with defaults
|
||||||
config := Config{
|
config := Config{
|
||||||
HealthCheckTimeout: 120,
|
HealthCheckTimeout: 120,
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
LogTimeFormat: "",
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
MetricsMaxInMemory: 1000,
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
GlobalTTL: 0,
|
||||||
}
|
}
|
||||||
err = yaml.Unmarshal(data, &config)
|
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
if config.HealthCheckTimeout < 15 {
|
||||||
// set a minimum of 15 seconds
|
|
||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +224,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.GlobalTTL < 0 {
|
||||||
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.LogToStdout {
|
||||||
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||||
|
}
|
||||||
|
|
||||||
// Populate the aliases map
|
// Populate the aliases map
|
||||||
config.aliases = make(map[string]string)
|
config.aliases = make(map[string]string)
|
||||||
for modelName, modelConfig := range config.Models {
|
for modelName, modelConfig := range config.Models {
|
||||||
@@ -204,55 +245,55 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* check macro constraint rules:
|
// Validate global macros
|
||||||
|
|
||||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
|
||||||
- names must be less than 64 characters (no reason, just cause)
|
|
||||||
- name can not be any reserved macros: PORT, MODEL_ID
|
|
||||||
- macro values must be less than 1024 characters
|
|
||||||
*/
|
|
||||||
for _, macro := range config.Macros {
|
for _, macro := range config.Macros {
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
// Get and sort all model IDs for consistent port assignment
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
for modelId := range config.Models {
|
for modelId := range config.Models {
|
||||||
modelIds = append(modelIds, modelId)
|
modelIds = append(modelIds, modelId)
|
||||||
}
|
}
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
sort.Strings(modelIds)
|
||||||
|
|
||||||
nextPort := config.StartPort
|
nextPort := config.StartPort
|
||||||
for _, modelId := range modelIds {
|
for _, modelId := range modelIds {
|
||||||
modelConfig := config.Models[modelId]
|
modelConfig := config.Models[modelId]
|
||||||
|
|
||||||
// Strip comments from command fields before macro expansion
|
// Strip comments from command fields
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
// validate model macros
|
// set model TTL to globalTTL it is the default value
|
||||||
|
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||||
|
modelConfig.UnloadAfter = config.GlobalTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.UnloadAfter < 0 {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate model macros
|
||||||
for _, macro := range modelConfig.Macros {
|
for _, macro := range modelConfig.Macros {
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge global config and model macros. Model macros take precedence
|
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
// Add global macros first
|
|
||||||
mergedMacros = append(mergedMacros, config.Macros...)
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
// Add model macros (can override global)
|
// Add model macros (override globals with same name)
|
||||||
for _, entry := range modelConfig.Macros {
|
for _, entry := range modelConfig.Macros {
|
||||||
// Remove any existing global macro with same name
|
|
||||||
found := false
|
found := false
|
||||||
for i, existing := range mergedMacros {
|
for i, existing := range mergedMacros {
|
||||||
if existing.Name == entry.Name {
|
if existing.Name == entry.Name {
|
||||||
mergedMacros[i] = entry // Override
|
mergedMacros[i] = entry
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -262,23 +303,40 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
// Substitute remaining macros in model fields (LIFO order)
|
||||||
// This allows later macros to reference earlier ones
|
|
||||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
entry := mergedMacros[i]
|
entry := mergedMacros[i]
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
// Substitute in command fields
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
// Substitute in metadata (recursive)
|
// Substitute macros in SetParamsByID keys and values
|
||||||
|
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||||
|
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||||
|
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
newParamMap, ok := newValAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||||
|
}
|
||||||
|
newSetParamsByID[newKey] = newParamMap
|
||||||
|
}
|
||||||
|
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute in metadata (type-preserving)
|
||||||
if len(modelConfig.Metadata) > 0 {
|
if len(modelConfig.Metadata) > 0 {
|
||||||
var err error
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
@@ -287,29 +345,25 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final pass: check if PORT macro is needed after macro expansion
|
// Handle PORT macro - only allocate if cmd uses it
|
||||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
|
||||||
// if it is required in either cmd or proxy keys
|
|
||||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
if cmdHasPort || proxyHasPort { // either has it
|
if cmdHasPort || proxyHasPort {
|
||||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
if !cmdHasPort && proxyHasPort {
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add PORT macro and substitute it
|
|
||||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
|
||||||
macroSlug := "${PORT}"
|
macroSlug := "${PORT}"
|
||||||
macroStr := fmt.Sprintf("%v", nextPort)
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
// Substitute PORT in metadata
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
if len(modelConfig.Metadata) > 0 {
|
||||||
var err error
|
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
}
|
}
|
||||||
@@ -319,13 +373,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
nextPort++
|
nextPort++
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure there are no unknown macros that have not been replaced
|
// Validate no unknown macros remain
|
||||||
fieldMap := map[string]string{
|
fieldMap := map[string]string{
|
||||||
"cmd": modelConfig.Cmd,
|
"cmd": modelConfig.Cmd,
|
||||||
"cmdStop": modelConfig.CmdStop,
|
"cmdStop": modelConfig.CmdStop,
|
||||||
"proxy": modelConfig.Proxy,
|
"proxy": modelConfig.Proxy,
|
||||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
"name": modelConfig.Name,
|
||||||
|
"description": modelConfig.Description,
|
||||||
}
|
}
|
||||||
|
|
||||||
for fieldName, fieldValue := range fieldMap {
|
for fieldName, fieldValue := range fieldMap {
|
||||||
@@ -333,62 +389,94 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
macroName := match[1]
|
macroName := match[1]
|
||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
continue // this is ok, has to be replaced by process later
|
continue // replaced at runtime
|
||||||
}
|
}
|
||||||
// Reserved macros are always valid (they should have been substituted already)
|
|
||||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
}
|
}
|
||||||
// Any other macro is unknown
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for unknown macros in metadata
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
if len(modelConfig.Metadata) > 0 {
|
||||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||||
return Config{}, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the proxy URL.
|
// Validate SetParamsByID keys and values
|
||||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
return Config{}, fmt.Errorf(
|
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||||
"model %s: invalid proxy URL: %w", modelId, err,
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||||
)
|
}
|
||||||
|
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||||
|
for key := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if key == modelId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := config.Models[key]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||||
|
}
|
||||||
|
if existingModel, exists := config.aliases[key]; exists {
|
||||||
|
if existingModel != modelId {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||||
|
}
|
||||||
|
continue // already registered as explicit alias for this model
|
||||||
|
}
|
||||||
|
config.aliases[key] = modelId
|
||||||
|
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if sendLoadingState is nil, set it to the global config value
|
|
||||||
// see #366
|
|
||||||
if modelConfig.SendLoadingState == nil {
|
if modelConfig.SendLoadingState == nil {
|
||||||
v := config.SendLoadingState // copy it
|
v := config.SendLoadingState
|
||||||
modelConfig.SendLoadingState = &v
|
modelConfig.SendLoadingState = &v
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
config.Models[modelId] = modelConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
// groups XOR matrix
|
||||||
// check that members are all unique in the groups
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
for groupID, groupConfig := range config.Groups {
|
}
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
// Check for duplicates within this group
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
// Check if member is used in another group
|
if config.Matrix != nil {
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
|
}
|
||||||
|
config.ExpandedSets = expandedSets
|
||||||
|
} else {
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
|
// Validate group members
|
||||||
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
}
|
}
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up hooks preload
|
// Clean up hooks preload
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
var toPreload []string
|
var toPreload []string
|
||||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
@@ -400,10 +488,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
toPreload = append(toPreload, real)
|
toPreload = append(toPreload, real)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Hooks.OnStartup.Preload = toPreload
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate API keys (env macros already substituted at string level)
|
||||||
|
for i, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||||
|
}
|
||||||
|
config.RequiredAPIKeys[i] = apikey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process peers with global macro substitution
|
||||||
|
for peerName, peerConfig := range config.Peers {
|
||||||
|
// Substitute global macros (LIFO order)
|
||||||
|
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||||
|
entry := config.Macros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||||
|
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in setParams (type-preserving)
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||||
|
}
|
||||||
|
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Peers[peerName] = peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -534,20 +668,26 @@ func validateMacro(name string, value any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
func validateNestedForUnknownMacros(value any, context string) error {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
macroName := match[1]
|
macroName := match[1]
|
||||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||||
|
}
|
||||||
|
// Check for unsubstituted env macros
|
||||||
|
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range envMatches {
|
||||||
|
varName := match[1]
|
||||||
|
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
for _, val := range v {
|
for _, val := range v {
|
||||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -555,7 +695,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
|
|||||||
|
|
||||||
case []any:
|
case []any:
|
||||||
for _, val := range v {
|
for _, val := range v {
|
||||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -614,3 +754,67 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
|
|||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||||
|
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||||
|
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||||
|
// (which strips comments) and only checking the comment-free version for macros.
|
||||||
|
func substituteEnvMacros(s string) (string, error) {
|
||||||
|
// Unmarshal and remarshal to strip YAML comments
|
||||||
|
var raw any
|
||||||
|
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||||
|
// If YAML is invalid, fall back to scanning the original string
|
||||||
|
// so the user gets the env var error rather than a confusing YAML parse error
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
clean, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return substituteEnvMacrosInString(s, string(clean))
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||||
|
// them in target. This separation allows scanning comment-free YAML while
|
||||||
|
// substituting in the original string.
|
||||||
|
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||||
|
result := target
|
||||||
|
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
fullMatch := match[0] // ${env.VAR_NAME}
|
||||||
|
varName := match[1] // VAR_NAME
|
||||||
|
|
||||||
|
value, exists := os.LookupEnv(varName)
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the value for safe YAML substitution
|
||||||
|
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings.ReplaceAll(result, fullMatch, value)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||||
|
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||||
|
// for compatibility with double-quoted YAML strings.
|
||||||
|
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||||
|
// Reject values that would break YAML structure regardless of quoting context
|
||||||
|
if strings.ContainsAny(value, "\n\r\x00") {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||||
|
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||||
|
// In double-quoted contexts, they are interpreted correctly.
|
||||||
|
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||||
|
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -163,9 +163,19 @@ groups:
|
|||||||
|
|
||||||
modelLoadingState := false
|
modelLoadingState := false
|
||||||
|
|
||||||
|
defaultTimeout := TimeoutsConfig{
|
||||||
|
Connect: 30,
|
||||||
|
KeepAlive: 30,
|
||||||
|
ResponseHeader: 0,
|
||||||
|
TLSHandshake: 10,
|
||||||
|
ExpectContinue: 1,
|
||||||
|
IdleConn: 90,
|
||||||
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
LogTimeFormat: "",
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: MacroList{
|
Macros: MacroList{
|
||||||
{"svr-path", "path/to/server"},
|
{"svr-path", "path/to/server"},
|
||||||
@@ -186,6 +196,7 @@ groups:
|
|||||||
Name: "Model 1",
|
Name: "Model 1",
|
||||||
Description: "This is model 1",
|
Description: "This is model 1",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model2": {
|
"model2": {
|
||||||
Cmd: "path/to/server --arg1 one",
|
Cmd: "path/to/server --arg1 one",
|
||||||
@@ -194,6 +205,7 @@ groups:
|
|||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model3": {
|
"model3": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -202,6 +214,7 @@ groups:
|
|||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -210,10 +223,12 @@ groups:
|
|||||||
Aliases: []string{},
|
Aliases: []string{},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
MetricsMaxInMemory: 1000,
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {"model1", "model2"},
|
"test": {"model1", "model2"},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
@@ -761,3 +762,785 @@ models:
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
content: `apiKeys: [""]`,
|
||||||
|
expectedErr: "empty api key found in apiKeys",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blank spaces only",
|
||||||
|
content: `apiKeys: [" "]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: ` `",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains leading space",
|
||||||
|
content: `apiKeys: [" key123"]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains trailing space",
|
||||||
|
content: `apiKeys: ["key123 "]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains middle space",
|
||||||
|
content: `apiKeys: ["key 123"]`,
|
||||||
|
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty in list with valid keys",
|
||||||
|
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||||
|
expectedErr: "empty api key found in apiKeys",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Equal(t, tt.expectedErr, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||||
|
t.Run("env substitution in apiKeys", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_API_KEY", "secret-key-123")
|
||||||
|
|
||||||
|
content := `apiKeys: ["${env.TEST_API_KEY}"]`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_API_KEY_1", "key-one")
|
||||||
|
t.Setenv("TEST_API_KEY_2", "key-two")
|
||||||
|
|
||||||
|
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var in apiKeys", func(t *testing.T) {
|
||||||
|
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
// With string-level env substitution, error only includes var name
|
||||||
|
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env substitution results in empty key", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_EMPTY_KEY", "")
|
||||||
|
|
||||||
|
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "empty api key found in apiKeys", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_GlobalTTL(t *testing.T) {
|
||||||
|
t.Run("globalTTL sets default for models", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
globalTTL: 300
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: server --port ${PORT}
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 300, config.GlobalTTL)
|
||||||
|
assert.Equal(t, 300, config.Models["model1"].UnloadAfter)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
globalTTL: 300
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: server --port ${PORT}
|
||||||
|
ttl: 0
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
globalTTL: 300
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: server --port ${PORT}
|
||||||
|
ttl: 600
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 600, config.Models["model1"].UnloadAfter)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("globalTTL defaults to 0", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: server --port ${PORT}
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, config.GlobalTTL)
|
||||||
|
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative globalTTL rejected", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
globalTTL: -1
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: server --port ${PORT}
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "globalTTL must be >= 0")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_EnvMacros(t *testing.T) {
|
||||||
|
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "${env.TEST_MODEL_PATH}/llama-server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env substitution in multiple fields", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_HOST", "myserver")
|
||||||
|
t.Setenv("TEST_PORT", "9999")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --host ${env.TEST_HOST}"
|
||||||
|
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
|
||||||
|
checkEndpoint: "http://${env.TEST_HOST}/health"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
|
||||||
|
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
|
||||||
|
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env in global macro value", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_BASE_PATH", "/usr/local")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "${SERVER_PATH} --port 8080"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env in model-level macro value", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_MODEL_DIR", "/models/llama")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
|
||||||
|
cmd: "server --model ${MODEL_FILE}"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env in metadata", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_API_KEY", "secret123")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
metadata:
|
||||||
|
api_key: "${env.TEST_API_KEY}"
|
||||||
|
nested:
|
||||||
|
key: "${env.TEST_API_KEY}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
|
||||||
|
nested := config.Models["test"].Metadata["nested"].(map[string]any)
|
||||||
|
assert.Equal(t, "secret123", nested["key"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env in filters.stripParams", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
filters:
|
||||||
|
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env in cmdStop", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --port ${PORT}"
|
||||||
|
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
|
||||||
|
proxy: "http://localhost:${PORT}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var returns error", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "${env.UNDEFINED_VAR_12345}/server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
|
||||||
|
assert.Contains(t, err.Error(), "not set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var in global macro", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
|
||||||
|
assert.Contains(t, err.Error(), "not set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var in model macro", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
|
||||||
|
assert.Contains(t, err.Error(), "not set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var in metadata", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
metadata:
|
||||||
|
key: "${env.UNDEFINED_META_VAR}"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
|
||||||
|
assert.Contains(t, err.Error(), "not set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env combined with regular macros", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_ROOT", "/data")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
MODEL_BASE: "${env.TEST_ROOT}/models"
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple env vars in same string", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_USER", "admin")
|
||||||
|
t.Setenv("TEST_PASS", "secret")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env value with newline is rejected", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_MULTILINE", "line1\nline2")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --config ${env.TEST_MULTILINE}"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "TEST_MULTILINE")
|
||||||
|
assert.Contains(t, err.Error(), "newlines")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env value with carriage return is rejected", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_CR", "line1\rline2")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --config ${env.TEST_CR}"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "TEST_CR")
|
||||||
|
assert.Contains(t, err.Error(), "newlines")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_QUOTED", `value with "quotes"`)
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --arg \"${env.TEST_QUOTED}\""
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Quotes are escaped before YAML parsing, then YAML unescapes them
|
||||||
|
// Final result preserves the original value with quotes
|
||||||
|
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_BACKSLASH", `path\to\file`)
|
||||||
|
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server --path \"${env.TEST_BACKSLASH}\""
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Backslashes are escaped before YAML parsing, then YAML unescapes them
|
||||||
|
// Final result preserves the original value with backslashes
|
||||||
|
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
|
||||||
|
t.Run("env substitution in peer apiKey", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: "${env.TEST_PEER_API_KEY}"
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing env var in peer apiKey", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: "${env.NONEXISTENT_PEER_KEY}"
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
// With string-level env substitution, error only includes var name
|
||||||
|
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("static apiKey unchanged", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-static-key
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_PEER_KEY_1", "key-one")
|
||||||
|
t.Setenv("TEST_PEER_KEY_2", "key-two")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
peer1:
|
||||||
|
proxy: https://peer1.example.com
|
||||||
|
apiKey: "${env.TEST_PEER_KEY_1}"
|
||||||
|
models:
|
||||||
|
- model-a
|
||||||
|
peer2:
|
||||||
|
proxy: https://peer2.example.com
|
||||||
|
apiKey: "${env.TEST_PEER_KEY_2}"
|
||||||
|
models:
|
||||||
|
- model-b
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
|
||||||
|
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
API_KEY: sk-from-global-macro
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: "${API_KEY}"
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
STRIP_LIST: "temperature, top_p"
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
filters:
|
||||||
|
stripParams: "${STRIP_LIST}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
MAX_TOKENS: 4096
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
filters:
|
||||||
|
setParams:
|
||||||
|
max_tokens: "${MAX_TOKENS}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_RETENTION_POLICY", "deny")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
filters:
|
||||||
|
setParams:
|
||||||
|
data_collection: "${env.TEST_RETENTION_POLICY}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
filters:
|
||||||
|
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: "${UNDEFINED_MACRO}"
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
|
||||||
|
assert.Contains(t, err.Error(), "unknown macro")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
peers:
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
models:
|
||||||
|
- llama-3.1-8b
|
||||||
|
filters:
|
||||||
|
setParams:
|
||||||
|
value: "${UNDEFINED_MACRO}"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
|
||||||
|
assert.Contains(t, err.Error(), "unknown macro")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macros in comments are ignored", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
# apiKeys:
|
||||||
|
# - "${env.COMMENTED_OUT_KEY_1}"
|
||||||
|
# - "${env.COMMENTED_OUT_KEY_2}"
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
// These env vars are NOT set, but should not cause an error
|
||||||
|
// because they only appear in comment lines
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, config.RequiredAPIKeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macros in comments ignored while active ones resolve", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_ACTIVE_KEY", "active-key-value")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
# apiKeys: ["${env.COMMENTED_OUT_KEY}"]
|
||||||
|
apiKeys: ["${env.TEST_ACTIVE_KEY}"]
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"active-key-value"}, config.RequiredAPIKeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macros in indented comments are ignored", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: |
|
||||||
|
server
|
||||||
|
--port 8080
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
# metadata:
|
||||||
|
# api_key: "${env.SOME_UNSET_KEY}"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env macros in inline comments are ignored", func(t *testing.T) {
|
||||||
|
t.Setenv("TEST_INLINE_KEY", "real-value")
|
||||||
|
|
||||||
|
content := `
|
||||||
|
apiKeys: ["${env.TEST_INLINE_KEY}"] # TODO: add ${env.FUTURE_KEY} later
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "server"
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"real-value"}, config.RequiredAPIKeys)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_TimeoutsParsing(t *testing.T) {
|
||||||
|
configYaml := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: test-server --port ${PORT}
|
||||||
|
timeouts:
|
||||||
|
connect: 45
|
||||||
|
responseHeader: 120
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
modelConfig, found := config.Models["model1"]
|
||||||
|
require.True(t, found, "model1 should exist in config")
|
||||||
|
|
||||||
|
assert.Equal(t, 45, modelConfig.Timeouts.Connect)
|
||||||
|
assert.Equal(t, 120, modelConfig.Timeouts.ResponseHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_TimeoutsDefaults(t *testing.T) {
|
||||||
|
configYaml := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: test-server --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
modelConfig, found := config.Models["model1"]
|
||||||
|
require.True(t, found, "model1 should exist in config")
|
||||||
|
|
||||||
|
// Default values should be set during unmarshaling
|
||||||
|
assert.Equal(t, 30, modelConfig.Timeouts.Connect)
|
||||||
|
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
|
||||||
|
assert.Equal(t, 10, modelConfig.Timeouts.TLSHandshake)
|
||||||
|
assert.Equal(t, 1, modelConfig.Timeouts.ExpectContinue)
|
||||||
|
assert.Equal(t, 90, modelConfig.Timeouts.IdleConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_TimeoutsZeroAllowed(t *testing.T) {
|
||||||
|
configYaml := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: test-server --port ${PORT}
|
||||||
|
timeouts:
|
||||||
|
connect: 0
|
||||||
|
responseHeader: 0
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
modelConfig, found := config.Models["model1"]
|
||||||
|
require.True(t, found, "model1 should exist in config")
|
||||||
|
|
||||||
|
// Explicit 0 should be preserved (disables timeout)
|
||||||
|
assert.Equal(t, 0, modelConfig.Timeouts.Connect)
|
||||||
|
assert.Equal(t, 0, modelConfig.Timeouts.ResponseHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_PeerTimeoutsParsing(t *testing.T) {
|
||||||
|
configYaml := `
|
||||||
|
peers:
|
||||||
|
peer1:
|
||||||
|
proxy: http://example.com
|
||||||
|
models: [model1]
|
||||||
|
timeouts:
|
||||||
|
connect: 45
|
||||||
|
responseHeader: 120
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peerConfig, found := config.Peers["peer1"]
|
||||||
|
require.True(t, found, "peer1 should exist in config")
|
||||||
|
|
||||||
|
assert.Equal(t, 45, peerConfig.Timeouts.Connect)
|
||||||
|
assert.Equal(t, 120, peerConfig.Timeouts.ResponseHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_PeerTimeoutsDefaults(t *testing.T) {
|
||||||
|
configYaml := `
|
||||||
|
peers:
|
||||||
|
peer1:
|
||||||
|
proxy: http://example.com
|
||||||
|
models: [model1]
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peerConfig, found := config.Peers["peer1"]
|
||||||
|
require.True(t, found, "peer1 should exist in config")
|
||||||
|
|
||||||
|
// Default values should be set during unmarshaling
|
||||||
|
assert.Equal(t, 30, peerConfig.Timeouts.Connect)
|
||||||
|
assert.Equal(t, 60, peerConfig.Timeouts.ResponseHeader)
|
||||||
|
assert.Equal(t, 10, peerConfig.Timeouts.TLSHandshake)
|
||||||
|
assert.Equal(t, 1, peerConfig.Timeouts.ExpectContinue)
|
||||||
|
assert.Equal(t, 90, peerConfig.Timeouts.IdleConn)
|
||||||
|
}
|
||||||
|
|||||||
@@ -155,9 +155,19 @@ groups:
|
|||||||
|
|
||||||
modelLoadingState := false
|
modelLoadingState := false
|
||||||
|
|
||||||
|
defaultTimeout := TimeoutsConfig{
|
||||||
|
Connect: 30,
|
||||||
|
KeepAlive: 30,
|
||||||
|
ResponseHeader: 0,
|
||||||
|
TLSHandshake: 10,
|
||||||
|
ExpectContinue: 1,
|
||||||
|
IdleConn: 90,
|
||||||
|
}
|
||||||
|
|
||||||
expected := Config{
|
expected := Config{
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
LogTimeFormat: "",
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
StartPort: 5800,
|
StartPort: 5800,
|
||||||
Macros: MacroList{
|
Macros: MacroList{
|
||||||
{"svr-path", "path/to/server"},
|
{"svr-path", "path/to/server"},
|
||||||
@@ -172,6 +182,7 @@ groups:
|
|||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model2": {
|
"model2": {
|
||||||
Cmd: "path/to/server --arg1 one",
|
Cmd: "path/to/server --arg1 one",
|
||||||
@@ -181,6 +192,7 @@ groups:
|
|||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model3": {
|
"model3": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -190,6 +202,7 @@ groups:
|
|||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/",
|
CheckEndpoint: "/",
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
"model4": {
|
"model4": {
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
@@ -199,10 +212,12 @@ groups:
|
|||||||
Aliases: []string{},
|
Aliases: []string{},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
SendLoadingState: &modelLoadingState,
|
SendLoadingState: &modelLoadingState,
|
||||||
|
Timeouts: defaultTimeout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
MetricsMaxInMemory: 1000,
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
Profiles: map[string][]string{
|
Profiles: map[string][]string{
|
||||||
"test": {"model1", "model2"},
|
"test": {"model1", "model2"},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,114 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||||
|
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||||
|
var ProtectedParams = []string{"model"}
|
||||||
|
|
||||||
|
// Filters contains filter settings for modifying request parameters
|
||||||
|
// Used by both models and peers
|
||||||
|
type Filters struct {
|
||||||
|
// StripParams is a comma-separated list of parameters to remove from requests
|
||||||
|
// The "model" parameter can never be removed
|
||||||
|
StripParams string `yaml:"stripParams"`
|
||||||
|
|
||||||
|
// SetParams is a dictionary of parameters to set/override in requests
|
||||||
|
// Protected params (like "model") cannot be set
|
||||||
|
SetParams map[string]any `yaml:"setParams"`
|
||||||
|
|
||||||
|
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||||
|
// Useful with aliases: a single loaded model can behave differently depending on
|
||||||
|
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||||
|
// Protected params (like "model") cannot be set.
|
||||||
|
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||||
|
// with duplicates, empty strings, and protected params removed
|
||||||
|
func (f Filters) SanitizedStripParams() []string {
|
||||||
|
if f.StripParams == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := strings.Split(f.StripParams, ",")
|
||||||
|
cleaned := make([]string, 0, len(params))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, param := range params {
|
||||||
|
trimmed := strings.TrimSpace(param)
|
||||||
|
// Skip protected params, empty strings, and duplicates
|
||||||
|
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = true
|
||||||
|
cleaned = append(cleaned, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cleaned) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Sort(cleaned)
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||||
|
// with protected params removed and keys sorted for consistent iteration order.
|
||||||
|
// Returns nil if the ID has no entry or all its params are protected.
|
||||||
|
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||||
|
if len(f.SetParamsByID) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
params, found := f.SetParamsByID[requestedModelID]
|
||||||
|
if !found || len(params) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
result := make(map[string]any, len(params))
|
||||||
|
keys := make([]string, 0, len(params))
|
||||||
|
for key, value := range params {
|
||||||
|
if slices.Contains(ProtectedParams, key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[key] = value
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return result, keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||||
|
// and keys sorted for consistent iteration order
|
||||||
|
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||||
|
if len(f.SetParams) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]any, len(f.SetParams))
|
||||||
|
keys := make([]string, 0, len(f.SetParams))
|
||||||
|
|
||||||
|
for key, value := range f.SetParams {
|
||||||
|
// Skip protected params
|
||||||
|
if slices.Contains(ProtectedParams, key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[key] = value
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort keys for consistent ordering
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, keys
|
||||||
|
}
|
||||||
@@ -0,0 +1,285 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
stripParams string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
stripParams: "",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single param",
|
||||||
|
stripParams: "temperature",
|
||||||
|
want: []string{"temperature"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple params",
|
||||||
|
stripParams: "temperature, top_p, top_k",
|
||||||
|
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model param filtered",
|
||||||
|
stripParams: "model, temperature, top_p",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only model param",
|
||||||
|
stripParams: "model",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicates removed",
|
||||||
|
stripParams: "temperature, top_p, temperature",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra whitespace",
|
||||||
|
stripParams: " temperature , top_p ",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty values filtered",
|
||||||
|
stripParams: "temperature,,top_p,",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := Filters{StripParams: tt.stripParams}
|
||||||
|
got := f.SanitizedStripParams()
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setParams map[string]any
|
||||||
|
wantParams map[string]any
|
||||||
|
wantKeys []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty setParams",
|
||||||
|
setParams: nil,
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
setParams: map[string]any{},
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal params",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "protected model param filtered",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only protected param",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
},
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex nested values",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"provider": map[string]any{
|
||||||
|
"data_collection": "deny",
|
||||||
|
"allow_fallbacks": false,
|
||||||
|
},
|
||||||
|
"transforms": []string{"middle-out"},
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"provider": map[string]any{
|
||||||
|
"data_collection": "deny",
|
||||||
|
"allow_fallbacks": false,
|
||||||
|
},
|
||||||
|
"transforms": []string{"middle-out"},
|
||||||
|
},
|
||||||
|
wantKeys: []string{"provider", "transforms"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := Filters{SetParams: tt.setParams}
|
||||||
|
gotParams, gotKeys := f.SanitizedSetParams()
|
||||||
|
|
||||||
|
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||||
|
for i, key := range gotKeys {
|
||||||
|
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantParams == nil {
|
||||||
|
assert.Nil(t, gotParams, "expected nil params")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||||
|
for key, wantValue := range tt.wantParams {
|
||||||
|
gotValue, exists := gotParams[key]
|
||||||
|
assert.True(t, exists, "missing key: %s", key)
|
||||||
|
// Simple comparison for basic types
|
||||||
|
switch v := wantValue.(type) {
|
||||||
|
case string, int, float64, bool:
|
||||||
|
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setParamsByID map[string]map[string]any
|
||||||
|
requestedModelID string
|
||||||
|
wantParams map[string]any
|
||||||
|
wantKeys []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty SetParamsByID returns nil",
|
||||||
|
setParamsByID: nil,
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map returns nil",
|
||||||
|
setParamsByID: map[string]map[string]any{},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching model ID returns nil",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model2": {"temperature": 0.9},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matching model ID returns correct params",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||||
|
"model2": {"temperature": 0.5},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "protected param model is filtered out",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model1": {
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only protected param returns nil",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model1": {
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keys are sorted",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model1": {
|
||||||
|
"z_param": "z",
|
||||||
|
"a_param": "a",
|
||||||
|
"m_param": "m",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1",
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"z_param": "z",
|
||||||
|
"a_param": "a",
|
||||||
|
"m_param": "m",
|
||||||
|
},
|
||||||
|
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alias style key lookup",
|
||||||
|
setParamsByID: map[string]map[string]any{
|
||||||
|
"model1:high": {"reasoning_effort": "high"},
|
||||||
|
"model1:low": {"reasoning_effort": "low"},
|
||||||
|
},
|
||||||
|
requestedModelID: "model1:high",
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"reasoning_effort": "high",
|
||||||
|
},
|
||||||
|
wantKeys: []string{"reasoning_effort"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||||
|
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||||
|
|
||||||
|
if tt.wantParams == nil {
|
||||||
|
assert.Nil(t, gotParams)
|
||||||
|
assert.Nil(t, gotKeys)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||||
|
assert.Equal(t, tt.wantParams, gotParams)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedParams(t *testing.T) {
|
||||||
|
// Verify that "model" is protected
|
||||||
|
assert.Contains(t, ProtectedParams, "model")
|
||||||
|
}
|
||||||
@@ -104,6 +104,62 @@ models:
|
|||||||
assert.Contains(t, err.Error(), "self-reference")
|
assert.Contains(t, err.Error(), "self-reference")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test macro substitution in name and description fields
|
||||||
|
func TestConfig_MacroInNameAndDescription(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"VARIANT": "Q4_K_M"
|
||||||
|
"FAMILY": "llama"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: echo ok
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
name: "${FAMILY} ${VARIANT}"
|
||||||
|
description: "A ${FAMILY} model in ${VARIANT} format"
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
|
||||||
|
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MODEL_ID macro in name and description fields
|
||||||
|
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
llama-3b:
|
||||||
|
cmd: echo ok
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
name: "Model: ${MODEL_ID}"
|
||||||
|
description: "Running ${MODEL_ID}"
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
|
||||||
|
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unknown macro in name or description returns an error
|
||||||
|
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ok
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
name: "Model ${UNDEFINED}"
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||||
|
}
|
||||||
|
|
||||||
// Test undefined macro reference error
|
// Test undefined macro reference error
|
||||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||||
content := `
|
content := `
|
||||||
|
|||||||
@@ -0,0 +1,226 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
|
||||||
|
|
||||||
|
// MatrixConfig represents the swap matrix configuration block.
|
||||||
|
type MatrixConfig struct {
|
||||||
|
Var map[string]string `yaml:"vars"`
|
||||||
|
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||||
|
Sets OrderedSets `yaml:"sets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEntry is a single named set with its DSL expression.
|
||||||
|
type SetEntry struct {
|
||||||
|
Name string
|
||||||
|
DSL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
|
||||||
|
type OrderedSets []SetEntry
|
||||||
|
|
||||||
|
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("sets must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]SetEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode set name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dsl string
|
||||||
|
if err := valueNode.Decode(&dsl); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, SetEntry{Name: name, DSL: dsl})
|
||||||
|
}
|
||||||
|
|
||||||
|
*os = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandedSet is one valid combination of concurrent models (real model names).
|
||||||
|
type ExpandedSet struct {
|
||||||
|
SetName string
|
||||||
|
DSL string
|
||||||
|
Models []string // real model names, sorted
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateMatrix validates the matrix config and returns all expanded sets.
|
||||||
|
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
|
||||||
|
if len(matrix.Sets) == 0 {
|
||||||
|
return nil, fmt.Errorf("matrix must define at least one set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matrix.Var) == 0 {
|
||||||
|
return nil, fmt.Errorf("matrix must define at least one var")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate var entries
|
||||||
|
if matrix.Var != nil {
|
||||||
|
for id, modelName := range matrix.Var {
|
||||||
|
if !varKeyPattern.MatchString(id) {
|
||||||
|
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
|
||||||
|
}
|
||||||
|
if _, exists := models[modelName]; !exists {
|
||||||
|
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate evict_costs
|
||||||
|
if matrix.EvictCosts != nil {
|
||||||
|
for key, cost := range matrix.EvictCosts {
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
|
||||||
|
}
|
||||||
|
if _, ok := matrix.Var[key]; !ok {
|
||||||
|
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build dependency graph for +ref topological sort
|
||||||
|
setNames := make(map[string]bool)
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
setNames[entry.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := make(map[string][]string) // setName -> set names it depends on
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
refs, err := extractRefs(entry.DSL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
|
||||||
|
}
|
||||||
|
for _, ref := range refs {
|
||||||
|
if !setNames[ref] {
|
||||||
|
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deps[entry.Name] = refs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topological sort with cycle detection
|
||||||
|
order, err := topologicalSort(matrix.Sets, deps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand sets in topological order
|
||||||
|
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
|
||||||
|
var allExpanded []ExpandedSet
|
||||||
|
totalCombinations := 0
|
||||||
|
|
||||||
|
// Build ordered map for efficient lookup
|
||||||
|
setDSL := make(map[string]string)
|
||||||
|
for _, entry := range matrix.Sets {
|
||||||
|
setDSL[entry.Name] = entry.DSL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range order {
|
||||||
|
dsl := setDSL[name]
|
||||||
|
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("set %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedRefs[name] = combos
|
||||||
|
|
||||||
|
// Resolve var IDs to real model names
|
||||||
|
for _, combo := range combos {
|
||||||
|
resolved := make([]string, len(combo))
|
||||||
|
for i, ident := range combo {
|
||||||
|
realName, ok := matrix.Var[ident]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
|
||||||
|
}
|
||||||
|
resolved[i] = realName
|
||||||
|
}
|
||||||
|
sort.Strings(resolved)
|
||||||
|
allExpanded = append(allExpanded, ExpandedSet{
|
||||||
|
SetName: name,
|
||||||
|
DSL: dsl,
|
||||||
|
Models: resolved,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
totalCombinations += len(combos)
|
||||||
|
if totalCombinations > maxDSLExpansions {
|
||||||
|
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allExpanded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// topologicalSort returns set names in dependency order.
|
||||||
|
// Returns an error if a cycle is detected.
|
||||||
|
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
|
||||||
|
// States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||||
|
state := make(map[string]int)
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
var visit func(name string) error
|
||||||
|
visit = func(name string) error {
|
||||||
|
switch state[name] {
|
||||||
|
case 1:
|
||||||
|
return fmt.Errorf("circular reference detected involving set %q", name)
|
||||||
|
case 2:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state[name] = 1
|
||||||
|
|
||||||
|
for _, dep := range deps[name] {
|
||||||
|
if err := visit(dep); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state[name] = 2
|
||||||
|
order = append(order, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visit in definition order for deterministic output
|
||||||
|
for _, entry := range sets {
|
||||||
|
if state[entry.Name] == 0 {
|
||||||
|
if err := visit(entry.Name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return order, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolvedEvictCosts returns a map of real model name -> evict cost,
|
||||||
|
// resolving var IDs. Models not listed default to 1.
|
||||||
|
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
|
||||||
|
costs := make(map[string]int)
|
||||||
|
if m.EvictCosts == nil {
|
||||||
|
return costs
|
||||||
|
}
|
||||||
|
for key, cost := range m.EvictCosts {
|
||||||
|
// Resolve var ID if present
|
||||||
|
if realName, ok := m.Var[key]; ok {
|
||||||
|
costs[realName] = cost
|
||||||
|
} else {
|
||||||
|
costs[key] = cost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return costs
|
||||||
|
}
|
||||||
@@ -0,0 +1,376 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxDSLExpansions = 1000
|
||||||
|
|
||||||
|
// Token types for the DSL lexer
|
||||||
|
type tokenType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
tokIdent tokenType = iota // model alias or name
|
||||||
|
tokAnd // &
|
||||||
|
tokOr // |
|
||||||
|
tokLParen // (
|
||||||
|
tokRParen // )
|
||||||
|
tokRef // +setName
|
||||||
|
tokEOF
|
||||||
|
)
|
||||||
|
|
||||||
|
type token struct {
|
||||||
|
typ tokenType
|
||||||
|
val string
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenize splits a DSL string into tokens.
|
||||||
|
func tokenize(input string) ([]token, error) {
|
||||||
|
var tokens []token
|
||||||
|
i := 0
|
||||||
|
runes := []rune(input)
|
||||||
|
|
||||||
|
for i < len(runes) {
|
||||||
|
ch := runes[i]
|
||||||
|
|
||||||
|
// skip whitespace
|
||||||
|
if unicode.IsSpace(ch) {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ch {
|
||||||
|
case '&':
|
||||||
|
tokens = append(tokens, token{tokAnd, "&"})
|
||||||
|
i++
|
||||||
|
case '|':
|
||||||
|
tokens = append(tokens, token{tokOr, "|"})
|
||||||
|
i++
|
||||||
|
case '(':
|
||||||
|
tokens = append(tokens, token{tokLParen, "("})
|
||||||
|
i++
|
||||||
|
case ')':
|
||||||
|
tokens = append(tokens, token{tokRParen, ")"})
|
||||||
|
i++
|
||||||
|
case '+':
|
||||||
|
// +ref: read the identifier that follows
|
||||||
|
i++
|
||||||
|
start := i
|
||||||
|
for i < len(runes) && isIdentChar(runes[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i == start {
|
||||||
|
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token{tokRef, string(runes[start:i])})
|
||||||
|
default:
|
||||||
|
if isIdentChar(ch) {
|
||||||
|
start := i
|
||||||
|
for i < len(runes) && isIdentChar(runes[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = append(tokens, token{tokEOF, ""})
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentChar(ch rune) bool {
|
||||||
|
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
// AST node types
|
||||||
|
type dslNode interface {
|
||||||
|
dslNode()
|
||||||
|
}
|
||||||
|
|
||||||
|
type andNode struct {
|
||||||
|
children []dslNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type orNode struct {
|
||||||
|
children []dslNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type leafNode struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type refNode struct {
|
||||||
|
setName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (andNode) dslNode() {}
|
||||||
|
func (orNode) dslNode() {}
|
||||||
|
func (leafNode) dslNode() {}
|
||||||
|
func (refNode) dslNode() {}
|
||||||
|
|
||||||
|
// parser holds state for recursive-descent parsing.
|
||||||
|
type parser struct {
|
||||||
|
tokens []token
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) peek() token {
|
||||||
|
if p.pos < len(p.tokens) {
|
||||||
|
return p.tokens[p.pos]
|
||||||
|
}
|
||||||
|
return token{tokEOF, ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) next() token {
|
||||||
|
t := p.peek()
|
||||||
|
if t.typ != tokEOF {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) expect(typ tokenType) (token, error) {
|
||||||
|
t := p.next()
|
||||||
|
if t.typ != typ {
|
||||||
|
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grammar:
|
||||||
|
//
|
||||||
|
// expr = andExpr
|
||||||
|
// andExpr = orExpr ('&' orExpr)*
|
||||||
|
// orExpr = atom ('|' atom)*
|
||||||
|
// atom = ident | '+' ident | '(' expr ')'
|
||||||
|
//
|
||||||
|
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
|
||||||
|
func parse(tokens []token) (dslNode, error) {
|
||||||
|
p := &parser{tokens: tokens}
|
||||||
|
node, err := p.parseExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if p.peek().typ != tokEOF {
|
||||||
|
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseExpr() (dslNode, error) {
|
||||||
|
return p.parseOrExpr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseOrExpr() (dslNode, error) {
|
||||||
|
left, err := p.parseAndExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek().typ == tokOr {
|
||||||
|
children := []dslNode{left}
|
||||||
|
for p.peek().typ == tokOr {
|
||||||
|
p.next() // consume |
|
||||||
|
right, err := p.parseAndExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
children = append(children, right)
|
||||||
|
}
|
||||||
|
return orNode{children: children}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseAndExpr() (dslNode, error) {
|
||||||
|
left, err := p.parseAtom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek().typ == tokAnd {
|
||||||
|
children := []dslNode{left}
|
||||||
|
for p.peek().typ == tokAnd {
|
||||||
|
p.next() // consume &
|
||||||
|
right, err := p.parseAtom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
children = append(children, right)
|
||||||
|
}
|
||||||
|
return andNode{children: children}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *parser) parseAtom() (dslNode, error) {
|
||||||
|
t := p.peek()
|
||||||
|
|
||||||
|
switch t.typ {
|
||||||
|
case tokIdent:
|
||||||
|
p.next()
|
||||||
|
return leafNode{name: t.val}, nil
|
||||||
|
|
||||||
|
case tokRef:
|
||||||
|
p.next()
|
||||||
|
return refNode{setName: t.val}, nil
|
||||||
|
|
||||||
|
case tokLParen:
|
||||||
|
p.next() // consume (
|
||||||
|
node, err := p.parseExpr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := p.expect(tokRParen); err != nil {
|
||||||
|
return nil, fmt.Errorf("missing closing parenthesis")
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected token %q", t.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expand walks the AST and produces all combinations.
|
||||||
|
// resolvedRefs contains previously expanded sets for +ref resolution.
|
||||||
|
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case leafNode:
|
||||||
|
return [][]string{{n.name}}, nil
|
||||||
|
|
||||||
|
case refNode:
|
||||||
|
expanded, ok := resolvedRefs[n.setName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
|
||||||
|
}
|
||||||
|
// Return a copy
|
||||||
|
result := make([][]string, len(expanded))
|
||||||
|
for i, combo := range expanded {
|
||||||
|
result[i] = make([]string, len(combo))
|
||||||
|
copy(result[i], combo)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
case orNode:
|
||||||
|
// Union of all children's expansions
|
||||||
|
var result [][]string
|
||||||
|
for _, child := range n.children {
|
||||||
|
childResult, err := expand(child, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, childResult...)
|
||||||
|
if len(result) > maxDSLExpansions {
|
||||||
|
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
case andNode:
|
||||||
|
// Cartesian product across children
|
||||||
|
result := [][]string{{}} // start with one empty combo
|
||||||
|
for _, child := range n.children {
|
||||||
|
childResult, err := expand(child, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown node type %T", node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cartesianProduct computes the cartesian product of two sets of combinations.
|
||||||
|
// It returns an error if the product would exceed cap.
|
||||||
|
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
|
||||||
|
if int64(len(left))*int64(len(right)) > int64(cap) {
|
||||||
|
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
|
||||||
|
}
|
||||||
|
result := make([][]string, 0, len(left)*len(right))
|
||||||
|
for _, l := range left {
|
||||||
|
for _, r := range right {
|
||||||
|
combo := make([]string, 0, len(l)+len(r))
|
||||||
|
combo = append(combo, l...)
|
||||||
|
combo = append(combo, r...)
|
||||||
|
result = append(result, combo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
|
||||||
|
// resolvedRefs contains previously expanded sets for +ref inlining.
|
||||||
|
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||||
|
dsl = strings.TrimSpace(dsl)
|
||||||
|
if dsl == "" {
|
||||||
|
return nil, fmt.Errorf("empty DSL expression")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := tokenize(dsl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tokenize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tree, err := parse(tokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := expand(tree, resolvedRefs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate models within each combination and sort for consistency
|
||||||
|
for i, combo := range result {
|
||||||
|
result[i] = dedupAndSort(combo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dedupAndSort removes duplicate entries and sorts alphabetically.
|
||||||
|
func dedupAndSort(items []string) []string {
|
||||||
|
seen := make(map[string]bool, len(items))
|
||||||
|
var unique []string
|
||||||
|
for _, item := range items {
|
||||||
|
if !seen[item] {
|
||||||
|
seen[item] = true
|
||||||
|
unique = append(unique, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(unique)
|
||||||
|
return unique
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRefs scans a DSL string for +ref tokens without full parsing.
|
||||||
|
// Used for building the dependency graph for topological sorting.
|
||||||
|
func extractRefs(dsl string) ([]string, error) {
|
||||||
|
tokens, err := tokenize(dsl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var refs []string
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, t := range tokens {
|
||||||
|
if t.typ == tokRef && !seen[t.val] {
|
||||||
|
seen[t.val] = true
|
||||||
|
refs = append(refs, t.val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return refs, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDSL_Tokenize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expect []token
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single identifier",
|
||||||
|
input: "abc",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "abc"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identifier with hyphens and dots",
|
||||||
|
input: "model-name.v2",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "model-name.v2"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "and expression",
|
||||||
|
input: "a & b",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "or expression",
|
||||||
|
input: "a | b",
|
||||||
|
expect: []token{
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses",
|
||||||
|
input: "(a | b) & c",
|
||||||
|
expect: []token{
|
||||||
|
{tokLParen, "("},
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokRParen, ")"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "c"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ref token",
|
||||||
|
input: "+llms & v",
|
||||||
|
expect: []token{
|
||||||
|
{tokRef, "llms"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "v"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no whitespace",
|
||||||
|
input: "(a|b)&c",
|
||||||
|
expect: []token{
|
||||||
|
{tokLParen, "("},
|
||||||
|
{tokIdent, "a"},
|
||||||
|
{tokOr, "|"},
|
||||||
|
{tokIdent, "b"},
|
||||||
|
{tokRParen, ")"},
|
||||||
|
{tokAnd, "&"},
|
||||||
|
{tokIdent, "c"},
|
||||||
|
{tokEOF, ""},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty ref",
|
||||||
|
input: "+",
|
||||||
|
errMsg: "expected set name after '+'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid character",
|
||||||
|
input: "a @ b",
|
||||||
|
errMsg: "unexpected character",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokens, err := tokenize(tt.input)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expect, tokens)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ParseAndExpand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dsl string
|
||||||
|
refs map[string][][]string
|
||||||
|
expect [][]string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single model",
|
||||||
|
dsl: "L",
|
||||||
|
expect: [][]string{{"L"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two models with AND",
|
||||||
|
dsl: "a & b",
|
||||||
|
expect: [][]string{{"a", "b"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two models with OR",
|
||||||
|
dsl: "a | b",
|
||||||
|
expect: [][]string{{"a"}, {"b"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three models with OR",
|
||||||
|
dsl: "a | b | c",
|
||||||
|
expect: [][]string{{"a"}, {"b"}, {"c"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cartesian product (a|b) & (c|d)",
|
||||||
|
dsl: "(a | b) & (c | d)",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a", "c"},
|
||||||
|
{"a", "d"},
|
||||||
|
{"b", "c"},
|
||||||
|
{"b", "d"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three-way AND",
|
||||||
|
dsl: "a & b & c",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "(g | q | m) & v",
|
||||||
|
dsl: "(g | q | m) & v",
|
||||||
|
expect: [][]string{
|
||||||
|
{"g", "v"},
|
||||||
|
{"q", "v"},
|
||||||
|
{"m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "(g | q) & v & e",
|
||||||
|
dsl: "(g | q) & v & e",
|
||||||
|
expect: [][]string{
|
||||||
|
{"e", "g", "v"},
|
||||||
|
{"e", "q", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "precedence: a | b & c means a | (b & c)",
|
||||||
|
dsl: "a | b & c",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a"},
|
||||||
|
{"b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "+ref inlining",
|
||||||
|
dsl: "+llms & v",
|
||||||
|
refs: map[string][][]string{
|
||||||
|
"llms": {{"g"}, {"q"}, {"m"}},
|
||||||
|
},
|
||||||
|
expect: [][]string{
|
||||||
|
{"g", "v"},
|
||||||
|
{"q", "v"},
|
||||||
|
{"m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "+ref chained",
|
||||||
|
dsl: "+with_tts & e",
|
||||||
|
refs: map[string][][]string{
|
||||||
|
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
|
||||||
|
},
|
||||||
|
expect: [][]string{
|
||||||
|
{"e", "g", "v"},
|
||||||
|
{"e", "q", "v"},
|
||||||
|
{"e", "m", "v"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dedup within combination",
|
||||||
|
dsl: "a & a",
|
||||||
|
expect: [][]string{
|
||||||
|
{"a"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty expression",
|
||||||
|
dsl: "",
|
||||||
|
errMsg: "empty DSL expression",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched open paren",
|
||||||
|
dsl: "(a | b",
|
||||||
|
errMsg: "missing closing parenthesis",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched close paren",
|
||||||
|
dsl: "a | b)",
|
||||||
|
errMsg: "unexpected token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown ref",
|
||||||
|
dsl: "+unknown",
|
||||||
|
errMsg: "unknown set reference +unknown",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty parens",
|
||||||
|
dsl: "()",
|
||||||
|
errMsg: "unexpected token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
refs := tt.refs
|
||||||
|
if refs == nil {
|
||||||
|
refs = map[string][][]string{}
|
||||||
|
}
|
||||||
|
result, err := ParseAndExpandDSL(tt.dsl, refs)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expect, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ExpansionCap(t *testing.T) {
|
||||||
|
// Build an expression that would exceed 1000 combinations:
|
||||||
|
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
|
||||||
|
var aItems, bItems []string
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
aItems = append(aItems, fmt.Sprintf("a%d", i))
|
||||||
|
bItems = append(bItems, fmt.Sprintf("b%d", i))
|
||||||
|
}
|
||||||
|
dsl := fmt.Sprintf("(%s) & (%s)",
|
||||||
|
join(aItems, " | "),
|
||||||
|
join(bItems, " | "),
|
||||||
|
)
|
||||||
|
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDSL_ExtractRefs(t *testing.T) {
|
||||||
|
refs, err := extractRefs("+llms & v & +other")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"llms", "other"}, refs)
|
||||||
|
|
||||||
|
refs, err = extractRefs("a & b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, refs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func join(items []string, sep string) string {
|
||||||
|
result := ""
|
||||||
|
for i, item := range items {
|
||||||
|
if i > 0 {
|
||||||
|
result += sep
|
||||||
|
}
|
||||||
|
result += item
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -0,0 +1,305 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeModels(names ...string) map[string]ModelConfig {
|
||||||
|
m := make(map[string]ModelConfig)
|
||||||
|
for _, name := range names {
|
||||||
|
m[name] = ModelConfig{Cmd: "echo " + name}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_Basic(t *testing.T) {
|
||||||
|
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"q": "qwen",
|
||||||
|
"m": "mistral",
|
||||||
|
"v": "voxtral",
|
||||||
|
"L": "llama70B",
|
||||||
|
},
|
||||||
|
EvictCosts: map[string]int{
|
||||||
|
"L": 30,
|
||||||
|
"v": 50,
|
||||||
|
},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "standard", DSL: "(g | q | m) & v"},
|
||||||
|
{Name: "full", DSL: "L"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expanded, err := ValidateMatrix(matrix, models)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||||
|
// full expands to [llama70B]
|
||||||
|
assert.Len(t, expanded, 4)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[0].SetName)
|
||||||
|
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[1].SetName)
|
||||||
|
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "standard", expanded[2].SetName)
|
||||||
|
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
|
||||||
|
|
||||||
|
assert.Equal(t, "full", expanded[3].SetName)
|
||||||
|
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_WithRef(t *testing.T) {
|
||||||
|
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"q": "qwen",
|
||||||
|
"m": "mistral",
|
||||||
|
"v": "voxtral",
|
||||||
|
"e": "reranker",
|
||||||
|
},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "llms", DSL: "g | q | m"},
|
||||||
|
{Name: "with_tts", DSL: "+llms & v"},
|
||||||
|
{Name: "mega", DSL: "+with_tts & e"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expanded, err := ValidateMatrix(matrix, models)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// llms: [gemma], [qwen], [mistral]
|
||||||
|
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||||
|
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
|
||||||
|
assert.Len(t, expanded, 9)
|
||||||
|
|
||||||
|
// Check mega entries
|
||||||
|
megaEntries := filterBySetName(expanded, "mega")
|
||||||
|
assert.Len(t, megaEntries, 3)
|
||||||
|
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_MapIDRequired(t *testing.T) {
|
||||||
|
// DSL cannot use real model names directly — must use var IDs
|
||||||
|
models := makeModels("gemma", "voxtral")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "combo", DSL: "g & voxtral"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
alias string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
|
||||||
|
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
|
||||||
|
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{tt.alias: "gemma"},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"x": "nonexistent"},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "x"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
t.Run("zero cost", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"g": 0},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "positive integer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative cost", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"g": -1},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "positive integer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
EvictCosts: map[string]int{"unknown": 5},
|
||||||
|
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||||
|
}
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_CycleDetection(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "a", DSL: "+b"},
|
||||||
|
{Name: "b", DSL: "+a"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "circular reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "a", DSL: "+nonexistent"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "references undefined set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_NoSets(t *testing.T) {
|
||||||
|
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "at least one set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
|
||||||
|
models := makeModels("gemma")
|
||||||
|
|
||||||
|
matrix := MatrixConfig{
|
||||||
|
Var: map[string]string{"g": "gemma"},
|
||||||
|
Sets: OrderedSets{
|
||||||
|
{Name: "s", DSL: "g & nonexistent"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ValidateMatrix(matrix, models)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown var ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
|
||||||
|
mc := &MatrixConfig{
|
||||||
|
Var: map[string]string{
|
||||||
|
"g": "gemma",
|
||||||
|
"L": "llama70B",
|
||||||
|
},
|
||||||
|
EvictCosts: map[string]int{
|
||||||
|
"L": 30,
|
||||||
|
"g": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
costs := mc.ResolvedEvictCosts()
|
||||||
|
assert.Equal(t, 30, costs["llama70B"])
|
||||||
|
assert.Equal(t, 5, costs["gemma"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ConfigXOR(t *testing.T) {
|
||||||
|
// groups and matrix both defined
|
||||||
|
yaml := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: echo model1
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
members:
|
||||||
|
- model1
|
||||||
|
matrix:
|
||||||
|
sets:
|
||||||
|
s: "model1"
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "cannot use both")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
|
||||||
|
yaml := `
|
||||||
|
models:
|
||||||
|
gemma:
|
||||||
|
cmd: echo gemma
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
qwen:
|
||||||
|
cmd: echo qwen
|
||||||
|
proxy: http://localhost:8081
|
||||||
|
matrix:
|
||||||
|
vars:
|
||||||
|
g: gemma
|
||||||
|
q: qwen
|
||||||
|
sets:
|
||||||
|
combo: "g | q"
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, cfg.Matrix)
|
||||||
|
assert.Len(t, cfg.ExpandedSets, 2)
|
||||||
|
// Groups should be empty when matrix is used
|
||||||
|
assert.Empty(t, cfg.Groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
|
||||||
|
var result []ExpandedSet
|
||||||
|
for _, s := range sets {
|
||||||
|
if s.SetName == name {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -3,10 +3,23 @@ package config
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
// TimeoutsConfig holds timeout settings for proxy connections
|
||||||
|
// 0 = no timeout
|
||||||
|
type TimeoutsConfig struct {
|
||||||
|
Connect int `yaml:"connect"`
|
||||||
|
KeepAlive int `yaml:"keepalive"`
|
||||||
|
ResponseHeader int `yaml:"responseHeader"`
|
||||||
|
TLSHandshake int `yaml:"tlsHandshake"`
|
||||||
|
ExpectContinue int `yaml:"expectContinue"`
|
||||||
|
IdleConn int `yaml:"idleConn"`
|
||||||
|
}
|
||||||
|
|
||||||
type ModelConfig struct {
|
type ModelConfig struct {
|
||||||
Cmd string `yaml:"cmd"`
|
Cmd string `yaml:"cmd"`
|
||||||
CmdStop string `yaml:"cmdStop"`
|
CmdStop string `yaml:"cmdStop"`
|
||||||
@@ -38,6 +51,9 @@ type ModelConfig struct {
|
|||||||
|
|
||||||
// override global setting
|
// override global setting
|
||||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||||
|
|
||||||
|
// Timeout settings for proxy connections
|
||||||
|
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
@@ -49,12 +65,22 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||||||
Aliases: []string{},
|
Aliases: []string{},
|
||||||
Env: []string{},
|
Env: []string{},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
UnloadAfter: 0,
|
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||||
Unlisted: false,
|
Unlisted: false,
|
||||||
UseModelName: "",
|
UseModelName: "",
|
||||||
ConcurrencyLimit: 0,
|
ConcurrencyLimit: 0,
|
||||||
Name: "",
|
Name: "",
|
||||||
Description: "",
|
Description: "",
|
||||||
|
|
||||||
|
// matches http.DefaultTransport
|
||||||
|
Timeouts: TimeoutsConfig{
|
||||||
|
Connect: 30,
|
||||||
|
KeepAlive: 30,
|
||||||
|
ResponseHeader: 0,
|
||||||
|
TLSHandshake: 10,
|
||||||
|
ExpectContinue: 1,
|
||||||
|
IdleConn: 90,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
@@ -74,16 +100,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
return SanitizeCommand(m.Cmd)
|
return SanitizeCommand(m.Cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelFilters see issue #174
|
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||||
|
// See issue #174
|
||||||
type ModelFilters struct {
|
type ModelFilters struct {
|
||||||
StripParams string `yaml:"stripParams"`
|
Filters `yaml:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
type rawModelFilters ModelFilters
|
type rawModelFilters ModelFilters
|
||||||
defaults := rawModelFilters{
|
defaults := rawModelFilters{}
|
||||||
StripParams: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
if err := unmarshal(&defaults); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -104,25 +129,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||||
|
// Returns ([]string, error) to match existing API
|
||||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||||
if f.StripParams == "" {
|
return f.Filters.SanitizedStripParams(), nil
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
params := strings.Split(f.StripParams, ",")
|
|
||||||
cleaned := make([]string, 0, len(params))
|
|
||||||
seen := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, param := range params {
|
|
||||||
trimmed := strings.TrimSpace(param)
|
|
||||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[trimmed] = true
|
|
||||||
cleaned = append(cleaned, trimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort cleaned
|
|
||||||
slices.Sort(cleaned)
|
|
||||||
return cleaned, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,3 +72,101 @@ models:
|
|||||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
setParamsByID:
|
||||||
|
"${MODEL_ID}:high":
|
||||||
|
reasoning_effort: high
|
||||||
|
"${MODEL_ID}:low":
|
||||||
|
reasoning_effort: low
|
||||||
|
`
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Keys (other than the model's own ID) should be registered as aliases
|
||||||
|
realName, found := cfg.RealModelName("model1:high")
|
||||||
|
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||||
|
assert.Equal(t, "model1", realName)
|
||||||
|
|
||||||
|
realName, found = cfg.RealModelName("model1:low")
|
||||||
|
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||||
|
assert.Equal(t, "model1", realName)
|
||||||
|
|
||||||
|
// Auto-aliases should also appear in modelConfig.Aliases
|
||||||
|
aliases := cfg.Models["model1"].Aliases
|
||||||
|
assert.Contains(t, aliases, "model1:high")
|
||||||
|
assert.Contains(t, aliases, "model1:low")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
setParamsByID:
|
||||||
|
model2:
|
||||||
|
reasoning_effort: high
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
setParamsByID:
|
||||||
|
"shared-alias":
|
||||||
|
reasoning_effort: high
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
setParamsByID:
|
||||||
|
"shared-alias":
|
||||||
|
reasoning_effort: low
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.ErrorContains(t, err, "duplicate alias")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
stripParams: "top_k"
|
||||||
|
setParams:
|
||||||
|
temperature: 0.7
|
||||||
|
top_p: 0.9
|
||||||
|
stop:
|
||||||
|
- "<|end|>"
|
||||||
|
- "<|stop|>"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
modelConfig := config.Models["model1"]
|
||||||
|
|
||||||
|
// Check stripParams
|
||||||
|
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||||
|
|
||||||
|
// Check setParams
|
||||||
|
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||||
|
assert.NotNil(t, setParams)
|
||||||
|
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||||
|
assert.Equal(t, 0.7, setParams["temperature"])
|
||||||
|
assert.Equal(t, 0.9, setParams["top_p"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerDictionaryConfig map[string]PeerConfig
|
||||||
|
type PeerConfig struct {
|
||||||
|
Proxy string `yaml:"proxy"`
|
||||||
|
ProxyURL *url.URL `yaml:"-"`
|
||||||
|
ApiKey string `yaml:"apiKey"`
|
||||||
|
Models []string `yaml:"models"`
|
||||||
|
Filters Filters `yaml:"filters"`
|
||||||
|
|
||||||
|
// Timeout settings for proxy connections
|
||||||
|
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawPeerConfig PeerConfig
|
||||||
|
defaults := rawPeerConfig{
|
||||||
|
Proxy: "",
|
||||||
|
ApiKey: "",
|
||||||
|
Models: []string{},
|
||||||
|
Filters: Filters{},
|
||||||
|
|
||||||
|
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
|
||||||
|
// to match the pre PR #619 functionality
|
||||||
|
Timeouts: TimeoutsConfig{
|
||||||
|
Connect: 30,
|
||||||
|
KeepAlive: 30,
|
||||||
|
ResponseHeader: 60,
|
||||||
|
TLSHandshake: 10,
|
||||||
|
ExpectContinue: 1,
|
||||||
|
IdleConn: 90,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate proxy is not empty
|
||||||
|
if defaults.Proxy == "" {
|
||||||
|
return fmt.Errorf("proxy is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate proxy is a valid URL and store the parsed value
|
||||||
|
parsedURL, err := url.Parse(defaults.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||||
|
}
|
||||||
|
defaults.ProxyURL = parsedURL
|
||||||
|
|
||||||
|
// Validate models is not empty
|
||||||
|
if len(defaults.Models) == 0 {
|
||||||
|
return fmt.Errorf("peer models can not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = PeerConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
yaml string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
`,
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config with apiKey",
|
||||||
|
yaml: `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test-key
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
`,
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing proxy",
|
||||||
|
yaml: `
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "proxy is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty proxy",
|
||||||
|
yaml: `
|
||||||
|
proxy: ""
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "proxy is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid proxy URL",
|
||||||
|
yaml: `
|
||||||
|
proxy: "://invalid"
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "invalid peer proxy URL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing models",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`,
|
||||||
|
wantErr: "peer models can not be empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty models",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
models: []
|
||||||
|
`,
|
||||||
|
wantErr: "peer models can not be empty",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||||
|
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||||
|
} else if !contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: http://192.168.1.23:8080/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL == nil {
|
||||||
|
t.Fatal("ProxyURL should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||||
|
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Scheme != "http" {
|
||||||
|
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Path != "/api" {
|
||||||
|
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func searchSubstring(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
filters:
|
||||||
|
setParams:
|
||||||
|
temperature: 0.7
|
||||||
|
provider:
|
||||||
|
data_collection: deny
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Filters.SetParams == nil {
|
||||||
|
t.Fatal("Filters.SetParams should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||||
|
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("provider should be a map")
|
||||||
|
}
|
||||||
|
if provider["data_collection"] != "deny" {
|
||||||
|
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
filters:
|
||||||
|
stripParams: "temperature, top_p"
|
||||||
|
setParams:
|
||||||
|
max_tokens: 1000
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check stripParams
|
||||||
|
stripParams := config.Filters.SanitizedStripParams()
|
||||||
|
if len(stripParams) != 2 {
|
||||||
|
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||||
|
}
|
||||||
|
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||||
|
t.Errorf("unexpected strip params: %v", stripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check setParams
|
||||||
|
if config.Filters.SetParams == nil {
|
||||||
|
t.Fatal("Filters.SetParams should not be nil")
|
||||||
|
}
|
||||||
|
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||||
|
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
// Package configwatcher provides a simple cross-platform file watcher based
|
||||||
|
// on os.Stat polling. It works correctly inside Docker containers where the
|
||||||
|
// config file is bind-mounted as an individual file, and for k8s ConfigMap
|
||||||
|
// projections (which present the file as a symlink to an atomically swapped
|
||||||
|
// target) — both cases where inotify-based watchers are unreliable.
|
||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultInterval = 2 * time.Second
|
||||||
|
|
||||||
|
type Watcher struct {
|
||||||
|
Path string
|
||||||
|
Interval time.Duration
|
||||||
|
OnChange func()
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshot struct {
|
||||||
|
exists bool
|
||||||
|
modTime time.Time
|
||||||
|
size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||||
|
// OnChange whenever the file's modification time or size changes, or when
|
||||||
|
// the file reappears after being missing. The baseline poll establishes
|
||||||
|
// initial state and does not fire OnChange.
|
||||||
|
func (w *Watcher) Run(ctx context.Context) {
|
||||||
|
interval := w.Interval
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = DefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := stat(w.Path)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cur := stat(w.Path)
|
||||||
|
if changed(prev, cur) && w.OnChange != nil {
|
||||||
|
w.OnChange()
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stat(path string) snapshot {
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
log.Printf("configwatcher: stat %s: %v", path, err)
|
||||||
|
}
|
||||||
|
return snapshot{}
|
||||||
|
}
|
||||||
|
return snapshot{
|
||||||
|
exists: true,
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
size: fi.Size(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func changed(prev, cur snapshot) bool {
|
||||||
|
// Present → missing: stay quiet (likely a transient rename-style write).
|
||||||
|
// Missing → present: fire so we reload as soon as the file comes back.
|
||||||
|
if !cur.exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !prev.exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
|
||||||
|
}
|
||||||
@@ -0,0 +1,191 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testInterval = 25 * time.Millisecond
|
||||||
|
|
||||||
|
// startWatcher launches w.Run in a goroutine and returns a function that
|
||||||
|
// cancels the context and waits for Run to return.
|
||||||
|
func startWatcher(t *testing.T, w *Watcher) func() {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.Run(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("watcher did not stop within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForCount blocks until counter reaches want or timeout elapses.
|
||||||
|
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if atomic.LoadInt64(counter) >= want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_NoFireOnBaseline(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 5)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_DetectsModTimeChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
// Force a known baseline mtime.
|
||||||
|
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||||
|
require.NoError(t, os.Chtimes(path, base, base))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// Let the baseline settle.
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Bump mtime well above the baseline so low-resolution filesystems still notice.
|
||||||
|
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
originalMtime := fi.ModTime()
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
|
||||||
|
// Reset mtime back to the original so size is the only signal.
|
||||||
|
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetA := filepath.Join(dir, "targetA")
|
||||||
|
targetB := filepath.Join(dir, "targetB")
|
||||||
|
link := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
|
||||||
|
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
|
||||||
|
|
||||||
|
if err := os.Symlink(targetA, link); err != nil {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skipf("symlink creation requires privilege on Windows: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatalf("os.Symlink: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: link,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
|
||||||
|
// temp name, then rename over the existing one.
|
||||||
|
tmpLink := filepath.Join(dir, "config.yaml.tmp")
|
||||||
|
require.NoError(t, os.Symlink(targetB, tmpLink))
|
||||||
|
require.NoError(t, os.Rename(tmpLink, link))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_FileMissingThenReturns(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Remove(path))
|
||||||
|
time.Sleep(testInterval * 3)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
w := &Watcher{Path: path, Interval: testInterval}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { w.Run(ctx); close(done) }()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ const ConfigFileChangedEventID = 0x03
|
|||||||
const LogDataEventID = 0x04
|
const LogDataEventID = 0x04
|
||||||
const TokenMetricsEventID = 0x05
|
const TokenMetricsEventID = 0x05
|
||||||
const ModelPreloadedEventID = 0x06
|
const ModelPreloadedEventID = 0x06
|
||||||
|
const InFlightRequestsEventID = 0x07
|
||||||
|
|
||||||
type ProcessStateChangeEvent struct {
|
type ProcessStateChangeEvent struct {
|
||||||
ProcessName string
|
ProcessName string
|
||||||
@@ -58,3 +59,11 @@ type ModelPreloadedEvent struct {
|
|||||||
func (e ModelPreloadedEvent) Type() uint32 {
|
func (e ModelPreloadedEvent) Type() uint32 {
|
||||||
return ModelPreloadedEventID
|
return ModelPreloadedEventID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InFlightRequestsEvent struct {
|
||||||
|
Total int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e InFlightRequestsEvent) Type() uint32 {
|
||||||
|
return InFlightRequestsEventID
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,16 +73,30 @@ func getTestPort() int {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||||
|
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||||
|
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||||
|
t.Helper()
|
||||||
|
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||||
|
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||||
|
// Convert path to forward slashes for cross-platform compatibility
|
||||||
|
// Windows handles forward slashes in paths correctly
|
||||||
|
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||||
|
|
||||||
// Create a YAML string with just the values we want to set
|
// Create a YAML string with just the values we want to set
|
||||||
yamlStr := fmt.Sprintf(`
|
yamlStr := fmt.Sprintf(`
|
||||||
cmd: '%s --port %d --silent --respond %s'
|
cmd: '%s --port %d --silent --respond %s'
|
||||||
proxy: "http://127.0.0.1:%d"
|
proxy: "http://127.0.0.1:%d"
|
||||||
`, simpleResponderPath, port, expectedMessage, port)
|
`, cmdPath, port, expectedMessage, port)
|
||||||
|
|
||||||
var cfg config.ModelConfig
|
var cfg config.ModelConfig
|
||||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||||
@@ -84,3 +105,188 @@ proxy: "http://127.0.0.1:%d"
|
|||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||||
|
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||||
|
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||||
|
// ID itself is used.
|
||||||
|
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||||
|
for _, pg := range pm.processGroups {
|
||||||
|
for modelID, process := range pg.processes {
|
||||||
|
respond := modelID
|
||||||
|
if r, ok := modelResponses[modelID]; ok {
|
||||||
|
respond = r
|
||||||
|
}
|
||||||
|
process.testHandler = newTestHandler(respond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||||
|
// It supports the endpoints that routing tests depend on, without launching
|
||||||
|
// any subprocess or binding any port.
|
||||||
|
func newTestHandler(respond string) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||||
|
|
||||||
|
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||||
|
if d, err := time.ParseDuration(wait); err == nil {
|
||||||
|
time.Sleep(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isStreaming {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
flusher := w.(http.Flusher)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data, _ := json.Marshal(map[string]any{
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
finalData, _ := json.Marshal(map[string]any{
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": map[string]any{
|
||||||
|
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||||
|
"predicted_ms": 17, "predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"h_content_length": r.Header.Get("Content-Length"),
|
||||||
|
"request_body": string(bodyBytes),
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": map[string]any{
|
||||||
|
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||||
|
"predicted_ms": 17, "predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
if modelName != respond {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model := r.FormValue("model")
|
||||||
|
if model == "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, _, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fileBytes, _ := io.ReadAll(file)
|
||||||
|
file.Close()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||||
|
"model": model,
|
||||||
|
"h_content_type": r.Header.Get("Content-Type"),
|
||||||
|
"h_content_length": r.Header.Get("Content-Length"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
model := r.URL.Query().Get("model")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"voices": []string{"voice1"}, "model": model,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprint(w, respond)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"model": modelName, "images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"model": modelName, "images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"loras": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -12,6 +11,85 @@ import (
|
|||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||||
|
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||||
|
type circularBuffer struct {
|
||||||
|
data []byte // pre-allocated capacity
|
||||||
|
head int // next write position
|
||||||
|
size int // current number of bytes stored (0 to cap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCircularBuffer(capacity int) *circularBuffer {
|
||||||
|
return &circularBuffer{
|
||||||
|
data: make([]byte, capacity),
|
||||||
|
head: 0,
|
||||||
|
size: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||||
|
// Data is copied into the internal buffer (not stored by reference).
|
||||||
|
func (cb *circularBuffer) Write(p []byte) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cap := len(cb.data)
|
||||||
|
|
||||||
|
// If input is larger than capacity, only keep the last cap bytes
|
||||||
|
if len(p) >= cap {
|
||||||
|
copy(cb.data, p[len(p)-cap:])
|
||||||
|
cb.head = 0
|
||||||
|
cb.size = cap
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate how much space is available from head to end of buffer
|
||||||
|
firstPart := cap - cb.head
|
||||||
|
if firstPart >= len(p) {
|
||||||
|
// All data fits without wrapping
|
||||||
|
copy(cb.data[cb.head:], p)
|
||||||
|
cb.head = (cb.head + len(p)) % cap
|
||||||
|
} else {
|
||||||
|
// Data wraps around
|
||||||
|
copy(cb.data[cb.head:], p[:firstPart])
|
||||||
|
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||||
|
cb.head = len(p) - firstPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update size
|
||||||
|
cb.size += len(p)
|
||||||
|
if cb.size > cap {
|
||||||
|
cb.size = cap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||||
|
// Returns a new slice (copy), not a view into internal buffer.
|
||||||
|
func (cb *circularBuffer) GetHistory() []byte {
|
||||||
|
if cb.size == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]byte, cb.size)
|
||||||
|
cap := len(cb.data)
|
||||||
|
|
||||||
|
// Calculate start position (oldest data)
|
||||||
|
start := (cb.head - cb.size + cap) % cap
|
||||||
|
|
||||||
|
if start+cb.size <= cap {
|
||||||
|
// Data is contiguous, single copy
|
||||||
|
copy(result, cb.data[start:start+cb.size])
|
||||||
|
} else {
|
||||||
|
// Data wraps around, two copies
|
||||||
|
firstPart := cap - start
|
||||||
|
copy(result[:firstPart], cb.data[start:])
|
||||||
|
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -19,12 +97,14 @@ const (
|
|||||||
LevelInfo
|
LevelInfo
|
||||||
LevelWarn
|
LevelWarn
|
||||||
LevelError
|
LevelError
|
||||||
|
|
||||||
|
LogBufferSize = 100 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
eventbus *event.Dispatcher
|
eventbus *event.Dispatcher
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
buffer *ring.Ring
|
buffer *circularBuffer
|
||||||
bufferMu sync.RWMutex
|
bufferMu sync.RWMutex
|
||||||
|
|
||||||
// typically this can be os.Stdout
|
// typically this can be os.Stdout
|
||||||
@@ -45,7 +125,7 @@ func NewLogMonitor() *LogMonitor {
|
|||||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
return &LogMonitor{
|
return &LogMonitor{
|
||||||
eventbus: event.NewDispatcherConfig(1000),
|
eventbus: event.NewDispatcherConfig(1000),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: nil, // lazy initialized on first Write
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
level: LevelInfo,
|
level: LevelInfo,
|
||||||
prefix: "",
|
prefix: "",
|
||||||
@@ -64,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.bufferMu.Lock()
|
w.bufferMu.Lock()
|
||||||
bufferCopy := make([]byte, len(p))
|
if w.buffer == nil {
|
||||||
copy(bufferCopy, p)
|
w.buffer = newCircularBuffer(LogBufferSize)
|
||||||
w.buffer.Value = bufferCopy
|
}
|
||||||
w.buffer = w.buffer.Next()
|
w.buffer.Write(p)
|
||||||
w.bufferMu.Unlock()
|
w.bufferMu.Unlock()
|
||||||
|
|
||||||
|
// Make a copy for broadcast to preserve immutability
|
||||||
|
bufferCopy := make([]byte, len(p))
|
||||||
|
copy(bufferCopy, p)
|
||||||
w.broadcast(bufferCopy)
|
w.broadcast(bufferCopy)
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
@@ -77,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
|||||||
func (w *LogMonitor) GetHistory() []byte {
|
func (w *LogMonitor) GetHistory() []byte {
|
||||||
w.bufferMu.RLock()
|
w.bufferMu.RLock()
|
||||||
defer w.bufferMu.RUnlock()
|
defer w.bufferMu.RUnlock()
|
||||||
|
if w.buffer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.buffer.GetHistory()
|
||||||
|
}
|
||||||
|
|
||||||
var history []byte
|
// Clear releases the buffer memory, making it eligible for GC.
|
||||||
w.buffer.Do(func(p any) {
|
// The buffer will be lazily re-allocated on the next Write.
|
||||||
if p != nil {
|
func (w *LogMonitor) Clear() {
|
||||||
if content, ok := p.([]byte); ok {
|
w.bufferMu.Lock()
|
||||||
history = append(history, content...)
|
w.buffer = nil
|
||||||
}
|
w.bufferMu.Unlock()
|
||||||
}
|
|
||||||
})
|
|
||||||
return history
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||||
|
|||||||
@@ -113,3 +113,204 @@ func TestWrite_LogTimeFormat(t *testing.T) {
|
|||||||
t.Fatalf("Cannot find timestamp: %v", err)
|
t.Fatalf("Cannot find timestamp: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||||
|
// Create a small buffer to test wrap-around
|
||||||
|
cb := newCircularBuffer(10)
|
||||||
|
|
||||||
|
// Write "hello" (5 bytes)
|
||||||
|
cb.Write([]byte("hello"))
|
||||||
|
if got := string(cb.GetHistory()); got != "hello" {
|
||||||
|
t.Errorf("Expected 'hello', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write "world" (5 bytes) - buffer now full
|
||||||
|
cb.Write([]byte("world"))
|
||||||
|
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||||
|
t.Errorf("Expected 'helloworld', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||||
|
cb.Write([]byte("12345"))
|
||||||
|
if got := string(cb.GetHistory()); got != "world12345" {
|
||||||
|
t.Errorf("Expected 'world12345', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write data larger than buffer capacity
|
||||||
|
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||||
|
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||||
|
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||||
|
// Test empty buffer
|
||||||
|
cb := newCircularBuffer(10)
|
||||||
|
if got := cb.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test exact capacity
|
||||||
|
cb.Write([]byte("1234567890"))
|
||||||
|
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||||
|
t.Errorf("Expected '1234567890', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test write exactly at capacity boundary
|
||||||
|
cb = newCircularBuffer(10)
|
||||||
|
cb.Write([]byte("12345"))
|
||||||
|
cb.Write([]byte("67890"))
|
||||||
|
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||||
|
t.Errorf("Expected '1234567890', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Buffer should be nil before any writes
|
||||||
|
if lm.buffer != nil {
|
||||||
|
t.Error("Expected buffer to be nil before first write")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistory should return nil when buffer is nil
|
||||||
|
if got := lm.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil history before first write, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write should lazily initialize the buffer
|
||||||
|
lm.Write([]byte("test"))
|
||||||
|
|
||||||
|
if lm.buffer == nil {
|
||||||
|
t.Error("Expected buffer to be initialized after write")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := string(lm.GetHistory()); got != "test" {
|
||||||
|
t.Errorf("Expected 'test', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_Clear(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Write some data
|
||||||
|
lm.Write([]byte("hello"))
|
||||||
|
if got := string(lm.GetHistory()); got != "hello" {
|
||||||
|
t.Errorf("Expected 'hello', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear should release the buffer
|
||||||
|
lm.Clear()
|
||||||
|
|
||||||
|
if lm.buffer != nil {
|
||||||
|
t.Error("Expected buffer to be nil after Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := lm.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Write, clear, then write again
|
||||||
|
lm.Write([]byte("first"))
|
||||||
|
lm.Clear()
|
||||||
|
lm.Write([]byte("second"))
|
||||||
|
|
||||||
|
if got := string(lm.GetHistory()); got != "second" {
|
||||||
|
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||||
|
// Test data of varying sizes
|
||||||
|
smallMsg := []byte("small message\n")
|
||||||
|
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||||
|
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||||
|
|
||||||
|
b.Run("SmallWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(smallMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("MediumWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("LargeWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(largeMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WithSubscribers", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
// Add some subscribers
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
lm.OnLogData(func(data []byte) {})
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("GetHistory", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
// Pre-populate with data
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.GetHistory()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Benchmark Results - MBP M1 Pro
|
||||||
|
|
||||||
|
Before (ring.Ring):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|----------|-----------|
|
||||||
|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||||
|
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||||
|
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||||
|
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||||
|
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||||
|
|
||||||
|
After (circularBuffer 10KB):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|----------|-----------|
|
||||||
|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||||
|
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||||
|
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||||
|
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||||
|
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||||
|
|
||||||
|
After (circularBuffer 100KB):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|-----------|-----------|
|
||||||
|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||||
|
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||||
|
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||||
|
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||||
|
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||||
|
|
||||||
|
Summary:
|
||||||
|
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||||
|
- Allocations: reduced from 2 to 1 across all operations
|
||||||
|
- Small/medium writes: ~1.1-1.6x faster
|
||||||
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,329 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||||
|
// It is safe for concurrent reads after construction.
|
||||||
|
type MatrixSolver struct {
|
||||||
|
expandedSets []config.ExpandedSet // all valid model combinations
|
||||||
|
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||||
|
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||||
|
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||||
|
modelToSets := make(map[string][]int)
|
||||||
|
for i, es := range expandedSets {
|
||||||
|
for _, model := range es.Models {
|
||||||
|
modelToSets[model] = append(modelToSets[model], i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MatrixSolver{
|
||||||
|
expandedSets: expandedSets,
|
||||||
|
evictCosts: evictCosts,
|
||||||
|
modelToSets: modelToSets,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SolveResult describes what the solver decided.
|
||||||
|
type SolveResult struct {
|
||||||
|
Evict []string // running models that must be stopped
|
||||||
|
TargetSet []string // the chosen set of models (for informational purposes)
|
||||||
|
SetName string // name of the chosen set
|
||||||
|
DSL string // original DSL expression for the chosen set
|
||||||
|
TotalCost int // total eviction cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve determines which models to evict when a model is requested.
|
||||||
|
//
|
||||||
|
// Algorithm:
|
||||||
|
// 1. If requestedModel is already running, no eviction needed.
|
||||||
|
// 2. Find all sets containing requestedModel.
|
||||||
|
// 3. If no sets found, the model runs alone; evict all running models.
|
||||||
|
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||||
|
// models NOT in that set.
|
||||||
|
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||||
|
// 6. Return models to evict and the chosen set.
|
||||||
|
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||||
|
// If already running, nothing to do (but fill in set info for logging)
|
||||||
|
if slices.Contains(runningModels, requestedModel) {
|
||||||
|
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||||
|
return SolveResult{
|
||||||
|
TargetSet: runningModels,
|
||||||
|
SetName: setName,
|
||||||
|
DSL: dsl,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateIndices := s.modelToSets[requestedModel]
|
||||||
|
|
||||||
|
// Model not in any set: runs alone, evict everything
|
||||||
|
if len(candidateIndices) == 0 {
|
||||||
|
evict := make([]string, len(runningModels))
|
||||||
|
copy(evict, runningModels)
|
||||||
|
return SolveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: []string{requestedModel},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the cheapest candidate set
|
||||||
|
bestCost := -1
|
||||||
|
bestIdx := -1
|
||||||
|
|
||||||
|
for _, idx := range candidateIndices {
|
||||||
|
setModels := s.expandedSets[idx].Models
|
||||||
|
cost := 0
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(setModels, running) {
|
||||||
|
cost += s.evictCost(running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||||
|
bestCost = cost
|
||||||
|
bestIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine which running models to evict
|
||||||
|
chosen := s.expandedSets[bestIdx]
|
||||||
|
var evict []string
|
||||||
|
for _, running := range runningModels {
|
||||||
|
if !slices.Contains(chosen.Models, running) {
|
||||||
|
evict = append(evict, running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SolveResult{
|
||||||
|
Evict: evict,
|
||||||
|
TargetSet: chosen.Models,
|
||||||
|
SetName: chosen.SetName,
|
||||||
|
DSL: chosen.DSL,
|
||||||
|
TotalCost: bestCost,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingSet finds the expanded set that contains all running models.
|
||||||
|
// Returns the set name and DSL, or empty strings if no match.
|
||||||
|
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||||
|
for _, idx := range s.modelToSets[requestedModel] {
|
||||||
|
set := s.expandedSets[idx]
|
||||||
|
allInSet := true
|
||||||
|
for _, m := range runningModels {
|
||||||
|
if !slices.Contains(set.Models, m) {
|
||||||
|
allInSet = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allInSet {
|
||||||
|
return set.SetName, set.DSL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MatrixSolver) evictCost(model string) int {
|
||||||
|
if cost, ok := s.evictCosts[model]; ok {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matrix manages processes using solver-based swap logic.
|
||||||
|
type Matrix struct {
|
||||||
|
sync.Mutex
|
||||||
|
solver *MatrixSolver
|
||||||
|
processes map[string]*Process // all processes keyed by real model name
|
||||||
|
config config.Config
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
|
||||||
|
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||||
|
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||||
|
// request that needs to evict models waits for inflight to drain under
|
||||||
|
// m.Lock before stopping anything. Without this, a request that
|
||||||
|
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||||
|
// races with Stop()'s Wait() and can be killed mid-request.
|
||||||
|
inflight sync.WaitGroup
|
||||||
|
|
||||||
|
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||||
|
// after m.Lock is released but before the request is dispatched to
|
||||||
|
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||||
|
// race window to deterministically reproduce the race.
|
||||||
|
testDelayFastPath func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||||
|
// model defined in the config (any model can run alone even if not in a set).
|
||||||
|
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
|
||||||
|
processes := make(map[string]*Process)
|
||||||
|
for modelID, modelConfig := range cfg.Models {
|
||||||
|
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||||
|
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||||
|
processes[modelID] = process
|
||||||
|
}
|
||||||
|
|
||||||
|
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||||
|
|
||||||
|
return &Matrix{
|
||||||
|
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||||
|
processes: processes,
|
||||||
|
config: cfg,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||||
|
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
process, ok := m.processes[modelID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
running := m.runningModels()
|
||||||
|
result, err := m.solver.Solve(modelID, running)
|
||||||
|
if err != nil {
|
||||||
|
m.Unlock()
|
||||||
|
return fmt.Errorf("matrix solver error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log solver decision
|
||||||
|
if len(result.Evict) > 0 {
|
||||||
|
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||||
|
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||||
|
} else if len(running) == 0 {
|
||||||
|
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||||
|
} else {
|
||||||
|
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict models that need to be stopped
|
||||||
|
if len(result.Evict) > 0 {
|
||||||
|
// Wait for any in-flight ProxyRequest calls to register on their
|
||||||
|
// Process before stopping anything. Without this, a request that
|
||||||
|
// released m.Lock but has not yet incremented
|
||||||
|
// Process.inFlightRequests races with Stop() and can be killed
|
||||||
|
// mid-request.
|
||||||
|
m.inflight.Wait()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, evictModel := range result.Evict {
|
||||||
|
if p, exists := m.processes[evictModel]; exists {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
p.Stop()
|
||||||
|
}(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register this request in inflight before releasing m.Lock so a
|
||||||
|
// concurrent eviction will wait for it to complete.
|
||||||
|
m.inflight.Add(1)
|
||||||
|
defer m.inflight.Done()
|
||||||
|
isFastPath := len(result.Evict) == 0
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
|
if isFastPath && m.testDelayFastPath != nil {
|
||||||
|
m.testDelayFastPath()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy the request (Process handles on-demand start)
|
||||||
|
process.ProxyRequest(w, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopProcesses stops all running processes.
|
||||||
|
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range m.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
p.StopImmediately()
|
||||||
|
default:
|
||||||
|
p.Stop()
|
||||||
|
}
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopProcess stops a single process by model ID.
|
||||||
|
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
|
process, ok := m.processes[modelID]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("process not found for %s", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown shuts down all processes.
|
||||||
|
func (m *Matrix) Shutdown() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range m.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(p *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
p.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||||
|
func (m *Matrix) RunningModels() []string {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
return m.runningModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runningModels returns running model names (caller must hold lock).
|
||||||
|
func (m *Matrix) runningModels() []string {
|
||||||
|
var running []string
|
||||||
|
for id, process := range m.processes {
|
||||||
|
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||||
|
running = append(running, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(running)
|
||||||
|
return running
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProcess returns the Process for a model.
|
||||||
|
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||||
|
p, ok := m.processes[modelID]
|
||||||
|
return p, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasModel returns true if the model is managed by this matrix.
|
||||||
|
func (m *Matrix) HasModel(modelID string) bool {
|
||||||
|
_, ok := m.processes[modelID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
@@ -0,0 +1,349 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper to build expanded sets for solver tests
|
||||||
|
func makeExpandedSets(sets ...struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
}) []config.ExpandedSet {
|
||||||
|
var result []config.ExpandedSet
|
||||||
|
for _, s := range sets {
|
||||||
|
result = append(result, config.ExpandedSet{
|
||||||
|
SetName: s.name,
|
||||||
|
Models: s.models,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func es(name string, models ...string) struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
} {
|
||||||
|
return struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
}{name, models}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("a", []string{"a"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||||
|
assert.Equal(t, "s1", result.SetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model "c" not in any set
|
||||||
|
result, err := solver.Solve("c", []string{"a", "b"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("c", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||||
|
// Set: [a, b]. Request a when b and c are running.
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(es("s1", "a", "b")),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("a", []string{"b", "c"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||||
|
assert.Equal(t, []string{"c"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||||
|
// Two sets containing model "a":
|
||||||
|
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||||
|
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "a", "v"),
|
||||||
|
es("s2", "a", "L"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50, "L": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
// v is running. Switching to a:
|
||||||
|
// s1 cost: v is in s1, so 0
|
||||||
|
// s2 cost: v is NOT in s2, so 50
|
||||||
|
// => pick s1
|
||||||
|
result, err := solver.Solve("a", []string{"v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||||
|
|
||||||
|
// L is running. Switching to a:
|
||||||
|
// s1 cost: L is NOT in s1, so 30
|
||||||
|
// s2 cost: L is in s2, so 0
|
||||||
|
// => pick s2
|
||||||
|
result, err = solver.Solve("a", []string{"L"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||||
|
// Two sets with identical cost. Definition order should win.
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "a", "x"),
|
||||||
|
es("s2", "a", "y"),
|
||||||
|
),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Nothing running, both sets cost 0. s1 is first.
|
||||||
|
result, err := solver.Solve("a", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||||
|
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||||
|
// Sets: [g,v], [g,m]
|
||||||
|
// Running: v, m. Request g.
|
||||||
|
// s1=[g,v]: evict m (cost 1), keep v
|
||||||
|
// s2=[g,m]: evict v (cost 50), keep m
|
||||||
|
// => pick s1
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "g", "v"),
|
||||||
|
es("s2", "g", "m"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("g", []string{"v", "m"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"m"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("s1", "g", "v"),
|
||||||
|
es("s2", "q", "v"),
|
||||||
|
),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := solver.Solve("g", []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||||
|
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||||
|
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||||
|
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||||
|
// pending request and kills it mid-start.
|
||||||
|
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
ExpandedSets: []config.ExpandedSet{
|
||||||
|
{SetName: "s1", Models: []string{"model1"}},
|
||||||
|
{SetName: "s2", Models: []string{"model2"}},
|
||||||
|
},
|
||||||
|
Matrix: &config.MatrixConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewMatrix(cfg, testLogger, testLogger)
|
||||||
|
defer m.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// Bypass real subprocesses so the test is fast and deterministic.
|
||||||
|
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: run a request through model1 so it reaches StateReady and
|
||||||
|
// subsequent requests take the no-eviction path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||||
|
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
// Install fast-path hook that signals arrival and waits for release.
|
||||||
|
// This parks R2 at the race window — after m.Lock is released but
|
||||||
|
// before Process.inFlightRequests.Add(1).
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
m.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
// R2: no-eviction request for model1. Will pause at the hook.
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Deterministically wait for R2 to reach the race window.
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// R3: request for model2 which requires evicting model1. Must wait for
|
||||||
|
// R2 to finish before touching model1.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||||
|
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||||
|
// holding the lock, so TryLock keeps failing.
|
||||||
|
for m.TryLock() {
|
||||||
|
m.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||||
|
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||||
|
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if m.processes["model1"].CurrentState() != StateReady ||
|
||||||
|
m.processes["model2"].CurrentState() != StateStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
done := false
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
done = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while an in-flight request is pending")
|
||||||
|
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||||
|
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||||
|
|
||||||
|
// Release R2 and let both requests finish.
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
assert.Equal(t, http.StatusOK, w3.Code)
|
||||||
|
assert.Contains(t, w3.Body.String(), "model2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||||
|
// Simulates the example config:
|
||||||
|
// standard: [g,v], [q,v], [m,v]
|
||||||
|
// with_rerank: [g,v,e], [q,v,e]
|
||||||
|
// creative: [g,sd], [q,sd]
|
||||||
|
// full: [L]
|
||||||
|
solver := NewMatrixSolver(
|
||||||
|
makeExpandedSets(
|
||||||
|
es("standard", "g", "v"),
|
||||||
|
es("standard", "q", "v"),
|
||||||
|
es("standard", "m", "v"),
|
||||||
|
es("with_rerank", "e", "g", "v"),
|
||||||
|
es("with_rerank", "e", "q", "v"),
|
||||||
|
es("creative", "g", "sd"),
|
||||||
|
es("creative", "q", "sd"),
|
||||||
|
es("full", "L"),
|
||||||
|
),
|
||||||
|
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Running: g, v. Request q.
|
||||||
|
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||||
|
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||||
|
// => tie, pick first by definition order = standard[q,v]
|
||||||
|
result, err := solver.Solve("q", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"g"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: g, v. Request L.
|
||||||
|
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||||
|
// Only one set contains L, so pick it.
|
||||||
|
result, err = solver.Solve("L", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: g, v. Request sd.
|
||||||
|
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||||
|
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||||
|
// => pick creative[g,sd]
|
||||||
|
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"v"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||||
|
|
||||||
|
// Running: q, v, e. Request g.
|
||||||
|
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||||
|
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||||
|
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||||
|
// => pick with_rerank[g,v,e]
|
||||||
|
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"q"}, result.Evict)
|
||||||
|
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||||
|
}
|
||||||
@@ -2,6 +2,8 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -11,10 +13,54 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||||
|
var zstdEncOptions = []zstd.EOption{
|
||||||
|
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdDecOptions are the shared zstd decoder options.
|
||||||
|
var zstdDecOptions = []zstd.DOption{}
|
||||||
|
|
||||||
|
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||||
|
var zstdEncPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||||
|
return enc
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||||
|
var zstdDecPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||||
|
return dec
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
|
||||||
|
// Returns compressed bytes and the original JSON byte count for logging.
|
||||||
|
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||||
|
jsonBytes, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||||
|
}
|
||||||
|
enc := zstdEncPool.Get().(*zstd.Encoder)
|
||||||
|
defer zstdEncPool.Put(enc)
|
||||||
|
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressCapture decompresses zstd-compressed JSON and returns it.
|
||||||
|
func decompressCapture(data []byte) ([]byte, error) {
|
||||||
|
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||||
|
defer zstdDecPool.Put(dec)
|
||||||
|
return dec.DecodeAll(data, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
@@ -26,6 +72,16 @@ type TokenMetrics struct {
|
|||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||||
DurationMs int `json:"duration_ms"`
|
DurationMs int `json:"duration_ms"`
|
||||||
|
HasCapture bool `json:"has_capture"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReqRespCapture struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
ReqPath string `json:"req_path"`
|
||||||
|
ReqHeaders map[string]string `json:"req_headers"`
|
||||||
|
ReqBody []byte `json:"req_body"`
|
||||||
|
RespHeaders map[string]string `json:"resp_headers"`
|
||||||
|
RespBody []byte `json:"resp_body"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenMetricsEvent represents a token metrics event
|
// TokenMetricsEvent represents a token metrics event
|
||||||
@@ -44,19 +100,32 @@ type metricsMonitor struct {
|
|||||||
maxMetrics int
|
maxMetrics int
|
||||||
nextID int
|
nextID int
|
||||||
logger *LogMonitor
|
logger *LogMonitor
|
||||||
|
|
||||||
|
// capture fields
|
||||||
|
enableCaptures bool
|
||||||
|
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
|
||||||
|
captureOrder []int // track insertion order for FIFO eviction
|
||||||
|
captureSize int // current total compressed size in bytes
|
||||||
|
maxCaptureSize int // max bytes for captures (uncompressed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||||
mp := &metricsMonitor{
|
// capture buffer size in megabytes; 0 disables captures.
|
||||||
logger: logger,
|
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||||
maxMetrics: maxMetrics,
|
return &metricsMonitor{
|
||||||
|
logger: logger,
|
||||||
|
maxMetrics: maxMetrics,
|
||||||
|
enableCaptures: captureBufferMB > 0,
|
||||||
|
captures: make(map[int][]byte),
|
||||||
|
captureOrder: make([]int, 0),
|
||||||
|
captureSize: 0,
|
||||||
|
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
return mp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMetrics adds a new metric to the collection and publishes an event
|
// addMetrics adds a new metric to the collection and publishes an event.
|
||||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
// Returns the assigned metric ID.
|
||||||
|
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||||
mp.mu.Lock()
|
mp.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
@@ -67,6 +136,84 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
|||||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||||
}
|
}
|
||||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
|
return metric.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||||
|
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
|
||||||
|
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||||
|
if !mp.enableCaptures {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
captureSize := len(compressed)
|
||||||
|
if captureSize > mp.maxCaptureSize {
|
||||||
|
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
|
||||||
|
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Evict oldest (FIFO) until room available for the compressed data
|
||||||
|
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||||
|
oldestID := mp.captureOrder[0]
|
||||||
|
mp.captureOrder = mp.captureOrder[1:]
|
||||||
|
if evicted, exists := mp.captures[oldestID]; exists {
|
||||||
|
l := len(evicted)
|
||||||
|
mp.captureSize -= l
|
||||||
|
delete(mp.captures, oldestID)
|
||||||
|
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.captures[capture.ID] = compressed
|
||||||
|
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||||
|
mp.captureSize += captureSize
|
||||||
|
|
||||||
|
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||||
|
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
data, exists := mp.captures[id]
|
||||||
|
return data, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
|
||||||
|
// If decompress=false, returns the raw zstd-compressed bytes.
|
||||||
|
// Returns nil if the capture is not found.
|
||||||
|
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
data, exists := mp.captures[id]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !decompress {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
decompressed, err := decompressCapture(data)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return decompressed
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMetrics returns a copy of the current metrics
|
// getMetrics returns a copy of the current metrics
|
||||||
@@ -95,7 +242,35 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
request *http.Request,
|
request *http.Request,
|
||||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||||
) error {
|
) error {
|
||||||
|
// Capture request body and headers if captures enabled
|
||||||
|
var reqBody []byte
|
||||||
|
var reqHeaders map[string]string
|
||||||
|
if mp.enableCaptures {
|
||||||
|
if request.Body != nil {
|
||||||
|
var err error
|
||||||
|
reqBody, err = io.ReadAll(request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||||
|
}
|
||||||
|
request.Body.Close()
|
||||||
|
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||||
|
}
|
||||||
|
reqHeaders = make(map[string]string)
|
||||||
|
for key, values := range request.Header {
|
||||||
|
if len(values) > 0 {
|
||||||
|
reqHeaders[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
redactHeaders(reqHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
recorder := newBodyCopier(writer)
|
recorder := newBodyCopier(writer)
|
||||||
|
|
||||||
|
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||||
|
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||||
|
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||||
|
}
|
||||||
|
|
||||||
if err := next(modelID, recorder, request); err != nil {
|
if err := next(modelID, recorder, request); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -108,30 +283,94 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize default metrics - these will always be recorded
|
||||||
|
tm := TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||||
|
}
|
||||||
|
|
||||||
body := recorder.body.Bytes()
|
body := recorder.body.Bytes()
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
mp.logger.Warn("metrics skipped, empty body")
|
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||||
|
mp.addMetrics(tm)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
// Decompress if needed
|
||||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
var err error
|
||||||
} else {
|
body, err = decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
mp.addMetrics(tm)
|
mp.addMetrics(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
|
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm = parsed
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if gjson.ValidBytes(body) {
|
if gjson.ValidBytes(body) {
|
||||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
|
parsed := gjson.ParseBytes(body)
|
||||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
usage := parsed.Get("usage")
|
||||||
} else {
|
timings := parsed.Get("timings")
|
||||||
mp.addMetrics(tm)
|
|
||||||
|
// extract timings for infill - response is an array, timings are in the last element
|
||||||
|
// see #463
|
||||||
|
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||||
|
if arr := parsed.Array(); len(arr) > 0 {
|
||||||
|
timings = arr[len(arr)-1].Get("timings")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.Exists() || timings.Exists() {
|
||||||
|
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||||
|
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm = parsedMetrics
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build capture if enabled and determine if it will be stored
|
||||||
|
var capture *ReqRespCapture
|
||||||
|
if mp.enableCaptures {
|
||||||
|
respHeaders := make(map[string]string)
|
||||||
|
for key, values := range recorder.Header() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
respHeaders[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
redactHeaders(respHeaders)
|
||||||
|
delete(respHeaders, "Content-Encoding")
|
||||||
|
capture = &ReqRespCapture{
|
||||||
|
ReqPath: request.URL.Path,
|
||||||
|
ReqHeaders: reqHeaders,
|
||||||
|
ReqBody: reqBody,
|
||||||
|
RespHeaders: respHeaders,
|
||||||
|
RespBody: body,
|
||||||
|
}
|
||||||
|
compressed, _, err := compressCapture(capture)
|
||||||
|
if err == nil && len(compressed) <= mp.maxCaptureSize {
|
||||||
|
tm.HasCapture = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
metricID := mp.addMetrics(tm)
|
||||||
|
|
||||||
|
// Store capture if enabled
|
||||||
|
if capture != nil {
|
||||||
|
capture.ID = metricID
|
||||||
|
mp.addCapture(*capture)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,19 +413,27 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
if gjson.ValidBytes(data) {
|
if gjson.ValidBytes(data) {
|
||||||
return parseMetrics(modelID, start, gjson.ParseBytes(data))
|
parsed := gjson.ParseBytes(data)
|
||||||
|
usage := parsed.Get("usage")
|
||||||
|
timings := parsed.Get("timings")
|
||||||
|
|
||||||
|
// v1/responses format nests usage under response.usage
|
||||||
|
if !usage.Exists() {
|
||||||
|
usage = parsed.Get("response.usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.Exists() || timings.Exists() {
|
||||||
|
return parseMetrics(modelID, start, usage, timings)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
|
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||||
usage := jsonData.Get("usage")
|
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||||
timings := jsonData.Get("timings")
|
|
||||||
if !usage.Exists() && !timings.Exists() {
|
|
||||||
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
|
|
||||||
}
|
|
||||||
// default values
|
// default values
|
||||||
cachedTokens := -1 // unknown or missing data
|
cachedTokens := -1 // unknown or missing data
|
||||||
outputTokens := 0
|
outputTokens := 0
|
||||||
@@ -195,22 +442,41 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
|||||||
// timings data
|
// timings data
|
||||||
tokensPerSecond := -1.0
|
tokensPerSecond := -1.0
|
||||||
promptPerSecond := -1.0
|
promptPerSecond := -1.0
|
||||||
durationMs := int(time.Since(start).Milliseconds())
|
durationMs := wallDurationMs
|
||||||
|
|
||||||
if usage.Exists() {
|
if usage.Exists() {
|
||||||
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||||
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
// v1/chat/completions
|
||||||
|
inputTokens = int(pt.Int())
|
||||||
|
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||||
|
// v1/messages
|
||||||
|
inputTokens = int(it.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||||
|
// v1/chat/completions
|
||||||
|
outputTokens = int(ct.Int())
|
||||||
|
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||||
|
outputTokens = int(ot.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||||
|
cachedTokens = int(ct.Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
if timings.Exists() {
|
if timings.Exists() {
|
||||||
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
inputTokens = int(timings.Get("prompt_n").Int())
|
||||||
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
outputTokens = int(timings.Get("predicted_n").Int())
|
||||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||||
|
if timingsDurationMs > durationMs {
|
||||||
|
durationMs = timingsDurationMs
|
||||||
|
}
|
||||||
|
|
||||||
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
cachedTokens = int(cachedValue.Int())
|
cachedTokens = int(cachedValue.Int())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -227,6 +493,25 @@ func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (Token
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decompressBody decompresses the body based on Content-Encoding header
|
||||||
|
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||||
|
case "gzip":
|
||||||
|
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
return io.ReadAll(reader)
|
||||||
|
case "deflate":
|
||||||
|
reader := flate.NewReader(bytes.NewReader(body))
|
||||||
|
defer reader.Close()
|
||||||
|
return io.ReadAll(reader)
|
||||||
|
default:
|
||||||
|
return body, nil // Return as-is for unknown/no encoding
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// responseBodyCopier records the response body and writes to the original response writer
|
// responseBodyCopier records the response body and writes to the original response writer
|
||||||
// while also capturing it in a buffer for later processing
|
// while also capturing it in a buffer for later processing
|
||||||
type responseBodyCopier struct {
|
type responseBodyCopier struct {
|
||||||
@@ -265,3 +550,43 @@ func (w *responseBodyCopier) Header() http.Header {
|
|||||||
func (w *responseBodyCopier) StartTime() time.Time {
|
func (w *responseBodyCopier) StartTime() time.Time {
|
||||||
return w.start
|
return w.start
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sensitiveHeaders lists headers that should be redacted in captures
|
||||||
|
var sensitiveHeaders = map[string]bool{
|
||||||
|
"authorization": true,
|
||||||
|
"proxy-authorization": true,
|
||||||
|
"cookie": true,
|
||||||
|
"set-cookie": true,
|
||||||
|
"x-api-key": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||||
|
func redactHeaders(headers map[string]string) {
|
||||||
|
for key := range headers {
|
||||||
|
if sensitiveHeaders[strings.ToLower(key)] {
|
||||||
|
headers[key] = "[REDACTED]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||||
|
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||||
|
// preferences while ensuring we can parse response bodies for metrics.
|
||||||
|
func filterAcceptEncoding(acceptEncoding string) string {
|
||||||
|
if acceptEncoding == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||||
|
var filtered []string
|
||||||
|
|
||||||
|
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||||
|
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||||
|
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||||
|
if supported[strings.ToLower(encoding)] {
|
||||||
|
filtered = append(filtered, strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(filtered, ", ")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -11,11 +15,12 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||||
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := TokenMetrics{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -34,7 +39,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("increments ID for each metric", func(t *testing.T) {
|
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
mm.addMetrics(TokenMetrics{Model: "model"})
|
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||||
@@ -48,7 +53,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("respects max metrics limit", func(t *testing.T) {
|
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 3)
|
mm := newMetricsMonitor(testLogger, 3, 0)
|
||||||
|
|
||||||
// Add 5 metrics
|
// Add 5 metrics
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
@@ -68,7 +73,7 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
receivedEvent := make(chan TokenMetricsEvent, 1)
|
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||||
cancel := event.On(func(e TokenMetricsEvent) {
|
cancel := event.On(func(e TokenMetricsEvent) {
|
||||||
@@ -98,14 +103,14 @@ func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
|||||||
|
|
||||||
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||||
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.NotNil(t, metrics)
|
assert.NotNil(t, metrics)
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 0, len(metrics))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns copy of metrics", func(t *testing.T) {
|
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
mm.addMetrics(TokenMetrics{Model: "model1"})
|
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||||
mm.addMetrics(TokenMetrics{Model: "model2"})
|
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||||
|
|
||||||
@@ -125,7 +130,7 @@ func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
|||||||
|
|
||||||
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||||
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
jsonData, err := mm.getMetricsJSON()
|
jsonData, err := mm.getMetricsJSON()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, jsonData)
|
assert.NotNil(t, jsonData)
|
||||||
@@ -137,7 +142,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
mm.addMetrics(TokenMetrics{
|
mm.addMetrics(TokenMetrics{
|
||||||
Model: "model1",
|
Model: "model1",
|
||||||
InputTokens: 100,
|
InputTokens: 100,
|
||||||
@@ -165,7 +170,7 @@ func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
|||||||
|
|
||||||
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||||
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := `{
|
responseBody := `{
|
||||||
"usage": {
|
"usage": {
|
||||||
@@ -196,7 +201,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("successful request with timings data", func(t *testing.T) {
|
t.Run("successful request with timings data", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := `{
|
responseBody := `{
|
||||||
"timings": {
|
"timings": {
|
||||||
@@ -236,7 +241,7 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("streaming request with SSE format", func(t *testing.T) {
|
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||||
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||||
@@ -272,7 +277,7 @@ data: [DONE]
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
@@ -291,8 +296,8 @@ data: [DONE]
|
|||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 0, len(metrics))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -307,11 +312,14 @@ data: [DONE]
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -328,11 +336,14 @@ data: [DONE]
|
|||||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
expectedErr := assert.AnError
|
expectedErr := assert.AnError
|
||||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
@@ -350,8 +361,8 @@ data: [DONE]
|
|||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 0, len(metrics))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := `{"result": "ok"}`
|
responseBody := `{"result": "ok"}`
|
||||||
|
|
||||||
@@ -367,10 +378,82 @@ data: [DONE]
|
|||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("infill request extracts timings from last array element", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
// Infill response is an array with timings in the last element
|
||||||
|
responseBody := `[
|
||||||
|
{"content": "first chunk"},
|
||||||
|
{"content": "second chunk"},
|
||||||
|
{"content": "final", "timings": {
|
||||||
|
"prompt_n": 150,
|
||||||
|
"predicted_n": 75,
|
||||||
|
"prompt_per_second": 200.5,
|
||||||
|
"predicted_per_second": 35.5,
|
||||||
|
"prompt_ms": 600.0,
|
||||||
|
"predicted_ms": 1800.0,
|
||||||
|
"cache_n": 30
|
||||||
|
}}
|
||||||
|
]`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/infill", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 150, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||||
|
assert.Equal(t, 30, metrics[0].CachedTokens)
|
||||||
|
assert.Equal(t, 200.5, metrics[0].PromptPerSecond)
|
||||||
|
assert.Equal(t, 35.5, metrics[0].TokensPerSecond)
|
||||||
|
assert.Equal(t, 2400, metrics[0].DurationMs) // 600 + 1800
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("infill request with empty array records minimal metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
responseBody := `[]`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/infill", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +508,7 @@ func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
|||||||
|
|
||||||
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||||
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 1000)
|
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
numGoroutines := 10
|
numGoroutines := 10
|
||||||
@@ -452,7 +535,7 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 100)
|
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||||
|
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
|
|
||||||
@@ -489,8 +572,29 @@ func TestMetricsMonitor_Concurrent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||||
|
t.Run("keeps wall clock duration when timings underreport request time", func(t *testing.T) {
|
||||||
|
start := time.Now().Add(-5 * time.Second)
|
||||||
|
usage := gjson.Parse(`{"prompt_tokens": 5, "completion_tokens": 1}`)
|
||||||
|
timings := gjson.Parse(`{
|
||||||
|
"prompt_n": 5,
|
||||||
|
"predicted_n": 1,
|
||||||
|
"prompt_per_second": 10.0,
|
||||||
|
"predicted_per_second": 2.0,
|
||||||
|
"prompt_ms": 5.0,
|
||||||
|
"predicted_ms": 15.0
|
||||||
|
}`)
|
||||||
|
|
||||||
|
metrics, err := parseMetrics("test-model", start, usage, timings)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 5, metrics.InputTokens)
|
||||||
|
assert.Equal(t, 1, metrics.OutputTokens)
|
||||||
|
assert.Equal(t, 10.0, metrics.PromptPerSecond)
|
||||||
|
assert.Equal(t, 2.0, metrics.TokensPerSecond)
|
||||||
|
assert.GreaterOrEqual(t, metrics.DurationMs, 5000)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("prefers timings over usage data", func(t *testing.T) {
|
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
// Timings should take precedence over usage
|
// Timings should take precedence over usage
|
||||||
responseBody := `{
|
responseBody := `{
|
||||||
@@ -530,7 +634,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := `{
|
responseBody := `{
|
||||||
"timings": {
|
"timings": {
|
||||||
@@ -565,7 +669,7 @@ func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
|||||||
|
|
||||||
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||||
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
// Metrics should be found in the last data line before [DONE]
|
// Metrics should be found in the last data line before [DONE]
|
||||||
responseBody := `data: {"choices":[{"text":"First"}]}
|
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||||
@@ -598,8 +702,8 @@ data: [DONE]
|
|||||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := `data: not json
|
responseBody := `data: not json
|
||||||
|
|
||||||
@@ -619,14 +723,46 @@ data: [DONE]
|
|||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
t.Run("v1/responses format with nested response.usage", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10)
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
// v1/responses SSE format: usage is nested under response.usage
|
||||||
|
responseBody := "event: response.completed\n" +
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_abc","object":"response","created_at":1773416985,"status":"completed","model":"test-model","output":[],"usage":{"input_tokens":17,"output_tokens":23,"total_tokens":40}}}` +
|
||||||
|
"\n\n"
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 17, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 23, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
responseBody := ``
|
responseBody := ``
|
||||||
|
|
||||||
@@ -642,17 +778,19 @@ data: [DONE]
|
|||||||
ginCtx, _ := gin.CreateTestContext(rec)
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
// Empty body should not trigger WrapHandler processing
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
metrics := mm.getMetrics()
|
metrics := mm.getMetrics()
|
||||||
assert.Equal(t, 0, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark tests
|
// Benchmark tests
|
||||||
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||||
mm := newMetricsMonitor(testLogger, 1000)
|
mm := newMetricsMonitor(testLogger, 1000, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := TokenMetrics{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -673,7 +811,7 @@ func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||||
// Test performance with a smaller buffer where wrapping occurs more frequently
|
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||||
mm := newMetricsMonitor(testLogger, 100)
|
mm := newMetricsMonitor(testLogger, 100, 0)
|
||||||
|
|
||||||
metric := TokenMetrics{
|
metric := TokenMetrics{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -691,3 +829,392 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
|||||||
mm.addMetrics(metric)
|
mm.addMetrics(metric)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||||
|
t.Run("gzip encoded response", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||||
|
|
||||||
|
// Compress with gzip
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gzWriter := gzip.NewWriter(&buf)
|
||||||
|
gzWriter.Write([]byte(responseBody))
|
||||||
|
gzWriter.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(compressedBody)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deflate encoded response", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
|
||||||
|
|
||||||
|
// Compress with deflate
|
||||||
|
var buf bytes.Buffer
|
||||||
|
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
|
||||||
|
flateWriter.Write([]byte(responseBody))
|
||||||
|
flateWriter.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Content-Encoding", "deflate")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(compressedBody)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
// Invalid compressed data
|
||||||
|
invalidData := []byte("this is not gzip data")
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(invalidData)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Should not return error, just log warning
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Content-Encoding", "unknown-encoding")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReqRespCapture_CompressedSize(t *testing.T) {
|
||||||
|
t.Run("compressed size is smaller than uncompressed", func(t *testing.T) {
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 1,
|
||||||
|
ReqPath: "/v1/chat/completions",
|
||||||
|
ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`),
|
||||||
|
RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, uncompressed, err := compressCapture(&capture)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Greater(t, uncompressed, 0)
|
||||||
|
assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty capture produces compressed output", func(t *testing.T) {
|
||||||
|
capture := ReqRespCapture{}
|
||||||
|
compressed, _, err := compressCapture(&capture)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, compressed)
|
||||||
|
assert.True(t, len(compressed) > 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_AddCapture(t *testing.T) {
|
||||||
|
t.Run("does nothing when captures disabled", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 0,
|
||||||
|
ReqBody: []byte("test"),
|
||||||
|
}
|
||||||
|
mm.addCapture(capture)
|
||||||
|
|
||||||
|
// Should not store capture
|
||||||
|
assert.Nil(t, mm.getCaptureByID(0, false))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 0,
|
||||||
|
ReqBody: []byte("test request"),
|
||||||
|
RespBody: []byte("test response"),
|
||||||
|
}
|
||||||
|
mm.addCapture(capture)
|
||||||
|
|
||||||
|
retrieved := mm.getCaptureByID(0, true)
|
||||||
|
assert.NotNil(t, retrieved)
|
||||||
|
|
||||||
|
var decoded ReqRespCapture
|
||||||
|
err := json.Unmarshal(retrieved, &decoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, decoded.ID)
|
||||||
|
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||||
|
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
|
||||||
|
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
|
||||||
|
mm.maxCaptureSize = 450
|
||||||
|
|
||||||
|
// Use random-looking data that doesn't compress well with zstd
|
||||||
|
rng := rand.New(rand.NewSource(42))
|
||||||
|
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)}
|
||||||
|
rng.Read(capture1.ReqBody)
|
||||||
|
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)}
|
||||||
|
rng.Read(capture2.ReqBody)
|
||||||
|
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)}
|
||||||
|
rng.Read(capture3.ReqBody)
|
||||||
|
|
||||||
|
mm.addCapture(capture1)
|
||||||
|
mm.addCapture(capture2)
|
||||||
|
// Adding capture3 should evict capture1
|
||||||
|
mm.addCapture(capture3)
|
||||||
|
|
||||||
|
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
|
||||||
|
retrieved := mm.getCaptureByID(1, true)
|
||||||
|
assert.NotNil(t, retrieved, "capture 1 should exist")
|
||||||
|
retrieved = mm.getCaptureByID(2, true)
|
||||||
|
assert.NotNil(t, retrieved, "capture 2 should exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
mm.maxCaptureSize = 100
|
||||||
|
|
||||||
|
// Use random data that doesn't compress well to create an oversized capture
|
||||||
|
rng := rand.New(rand.NewSource(99))
|
||||||
|
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)}
|
||||||
|
rng.Read(largeCapture.ReqBody)
|
||||||
|
mm.addCapture(largeCapture)
|
||||||
|
|
||||||
|
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
||||||
|
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
assert.Nil(t, mm.getCaptureByID(999, false))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns decompressed capture by ID", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 42,
|
||||||
|
ReqBody: []byte("test request"),
|
||||||
|
RespBody: []byte("test response"),
|
||||||
|
}
|
||||||
|
mm.addCapture(capture)
|
||||||
|
|
||||||
|
retrieved := mm.getCaptureByID(42, true)
|
||||||
|
assert.NotNil(t, retrieved)
|
||||||
|
|
||||||
|
var decoded ReqRespCapture
|
||||||
|
err := json.Unmarshal(retrieved, &decoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 42, decoded.ID)
|
||||||
|
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||||
|
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 42,
|
||||||
|
ReqBody: []byte("test request body"),
|
||||||
|
RespBody: []byte("test response body"),
|
||||||
|
}
|
||||||
|
mm.addCapture(capture)
|
||||||
|
|
||||||
|
compressed := mm.getCaptureByID(42, false)
|
||||||
|
assert.NotNil(t, compressed)
|
||||||
|
// Compressed data should not be valid JSON (it's zstd-compressed)
|
||||||
|
assert.False(t, gjson.ValidBytes(compressed))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedactHeaders(t *testing.T) {
|
||||||
|
t.Run("redacts sensitive headers", func(t *testing.T) {
|
||||||
|
headers := map[string]string{
|
||||||
|
"Authorization": "Bearer secret-token",
|
||||||
|
"Proxy-Authorization": "Basic creds",
|
||||||
|
"Cookie": "session=abc123",
|
||||||
|
"Set-Cookie": "session=xyz789",
|
||||||
|
"X-Api-Key": "sk-12345",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Custom": "safe-value",
|
||||||
|
}
|
||||||
|
|
||||||
|
redactHeaders(headers)
|
||||||
|
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["Authorization"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["Proxy-Authorization"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["Cookie"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["Set-Cookie"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["X-Api-Key"])
|
||||||
|
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||||
|
assert.Equal(t, "safe-value", headers["X-Custom"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles mixed case header names", func(t *testing.T) {
|
||||||
|
headers := map[string]string{
|
||||||
|
"authorization": "Bearer token",
|
||||||
|
"COOKIE": "session=abc",
|
||||||
|
"x-api-key": "key123",
|
||||||
|
}
|
||||||
|
|
||||||
|
redactHeaders(headers)
|
||||||
|
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["authorization"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["COOKIE"])
|
||||||
|
assert.Equal(t, "[REDACTED]", headers["x-api-key"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles empty headers", func(t *testing.T) {
|
||||||
|
headers := map[string]string{}
|
||||||
|
redactHeaders(headers)
|
||||||
|
assert.Empty(t, headers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
||||||
|
t.Run("captures request and response when enabled", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
requestBody := `{"model": "test", "prompt": "hello"}`
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Custom", "header-value")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer secret")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check metric was recorded
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
metricID := metrics[0].ID
|
||||||
|
|
||||||
|
// Check capture was stored with same ID (decompressed)
|
||||||
|
captureData := mm.getCaptureByID(metricID, true)
|
||||||
|
assert.NotNil(t, captureData)
|
||||||
|
|
||||||
|
var capture ReqRespCapture
|
||||||
|
err = json.Unmarshal(captureData, &capture)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, metricID, capture.ID)
|
||||||
|
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||||
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
|
assert.Equal(t, "/test", capture.ReqPath)
|
||||||
|
assert.Equal(t, "application/json", capture.ReqHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "[REDACTED]", capture.ReqHeaders["Authorization"])
|
||||||
|
assert.Equal(t, "application/json", capture.RespHeaders["Content-Type"])
|
||||||
|
assert.Equal(t, "header-value", capture.RespHeaders["X-Custom"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not capture when disabled", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 0)
|
||||||
|
|
||||||
|
requestBody := `{"model": "test"}`
|
||||||
|
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(requestBody))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Metrics should still be recorded
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
|
||||||
|
// But no capture
|
||||||
|
capture := mm.getCaptureByID(metrics[0].ID, false)
|
||||||
|
assert.Nil(t, capture)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type peerProxyMember struct {
|
||||||
|
peerID string
|
||||||
|
reverseProxy *httputil.ReverseProxy
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerProxy struct {
|
||||||
|
peers config.PeerDictionaryConfig
|
||||||
|
proxyMap map[string]*peerProxyMember
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||||
|
proxyMap := make(map[string]*peerProxyMember)
|
||||||
|
|
||||||
|
// Sort peer IDs for consistent iteration order
|
||||||
|
peerIDs := make([]string, 0, len(peers))
|
||||||
|
for peerID := range peers {
|
||||||
|
peerIDs = append(peerIDs, peerID)
|
||||||
|
}
|
||||||
|
sort.Strings(peerIDs)
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
peer := peers[peerID]
|
||||||
|
|
||||||
|
// Create a transport with per-peer timeout configuration
|
||||||
|
peerTransport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||||
|
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||||
|
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create reverse proxy for this peer
|
||||||
|
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||||
|
reverseProxy.Transport = peerTransport
|
||||||
|
|
||||||
|
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||||
|
originalDirector := reverseProxy.Director
|
||||||
|
reverseProxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
// Ensure Host header matches target URL for remote proxying
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||||
|
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||||
|
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||||
|
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||||
|
}
|
||||||
|
http.Error(w, errMsg, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
pp := &peerProxyMember{
|
||||||
|
peerID: peerID,
|
||||||
|
reverseProxy: reverseProxy,
|
||||||
|
apiKey: peer.ApiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map each model to this peer's proxy
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
if _, found := proxyMap[modelID]; found {
|
||||||
|
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
proxyMap[modelID] = pp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PeerProxy{
|
||||||
|
peers: peers,
|
||||||
|
proxyMap: proxyMap,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||||
|
_, found := p.proxyMap[modelID]
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||||
|
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||||
|
pp, found := p.proxyMap[modelID]
|
||||||
|
if !found {
|
||||||
|
return config.Filters{}
|
||||||
|
}
|
||||||
|
// Get the peer config using the peerID
|
||||||
|
peer, found := p.peers[pp.peerID]
|
||||||
|
if !found {
|
||||||
|
return config.Filters{}
|
||||||
|
}
|
||||||
|
return peer.Filters
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||||
|
return p.peers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||||
|
pp, found := p.proxyMap[model_id]
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject API key if configured for this peer
|
||||||
|
if pp.apiKey != "" {
|
||||||
|
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||||
|
request.Header.Set("x-api-key", pp.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
pp.reverseProxy.ServeHTTP(writer, request)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,311 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||||
|
peers := config.PeerDictionaryConfig{}
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, pm)
|
||||||
|
assert.Empty(t, pm.proxyMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "test-key",
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pm.proxyMap, 2)
|
||||||
|
assert.True(t, pm.HasPeerModel("model-a"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-b"))
|
||||||
|
assert.False(t, pm.HasPeerModel("model-c"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
"peer2": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"model-c", "model-d"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pm.proxyMap, 4)
|
||||||
|
assert.True(t, pm.HasPeerModel("model-a"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-b"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-c"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-d"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||||
|
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||||
|
// should be mapped, and a warning should be logged
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"alpha-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
"beta-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Should only have one entry for the duplicate model
|
||||||
|
assert.Len(t, pm.proxyMap, 1)
|
||||||
|
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasPeerModel(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"existing-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||||
|
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||||
|
peers := config.PeerDictionaryConfig{}
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_Success(t *testing.T) {
|
||||||
|
// Create a test server to act as the peer
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("response from peer"))
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "response from peer", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||||
|
// Create a test server that checks for the Authorization header
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "secret-api-key",
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||||
|
// Create a test server that checks for the Authorization header
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "", // No API key
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, receivedAuthHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||||
|
// Create a test server that checks the Host header
|
||||||
|
var receivedHost string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedHost = r.Host
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// The Host header should be set to the target URL's host
|
||||||
|
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||||
|
// Create a test server that returns SSE content type
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||||
|
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||||
|
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"test-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"model1"},
|
||||||
|
Timeouts: config.TimeoutsConfig{
|
||||||
|
Connect: 45,
|
||||||
|
ResponseHeader: 300,
|
||||||
|
TLSHandshake: 15,
|
||||||
|
ExpectContinue: 2,
|
||||||
|
IdleConn: 120,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, peerProxy)
|
||||||
|
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||||
|
|
||||||
|
// Verify the timeout values are actually applied to the transport
|
||||||
|
member, found := peerProxy.proxyMap["model1"]
|
||||||
|
require.True(t, found, "model1 should exist in proxyMap")
|
||||||
|
assert.NotNil(t, member.reverseProxy)
|
||||||
|
assert.NotNil(t, member.reverseProxy.Transport)
|
||||||
|
|
||||||
|
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||||
|
require.True(t, ok, "Transport should be *http.Transport")
|
||||||
|
|
||||||
|
// Verify all timeout values are correctly applied
|
||||||
|
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||||
|
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||||
|
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||||
|
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||||
|
// ForceAttemptHTTP2 should be enabled
|
||||||
|
assert.True(t, transport.ForceAttemptHTTP2)
|
||||||
|
}
|
||||||
@@ -77,6 +77,9 @@ type Process struct {
|
|||||||
// used for testing to override the default value
|
// used for testing to override the default value
|
||||||
gracefulStopTimeout time.Duration
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// used for testing to bypass subprocess and reverse proxy
|
||||||
|
testHandler http.Handler
|
||||||
|
|
||||||
// track the number of failed starts
|
// track the number of failed starts
|
||||||
failedStartCount int
|
failedStartCount int
|
||||||
}
|
}
|
||||||
@@ -96,6 +99,24 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
|
|||||||
var reverseProxy *httputil.ReverseProxy
|
var reverseProxy *httputil.ReverseProxy
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||||
|
|
||||||
|
// Create custom transport with configured timeouts
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: time.Duration(config.Timeouts.Connect) * time.Second,
|
||||||
|
KeepAlive: time.Duration(config.Timeouts.KeepAlive) * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: time.Duration(config.Timeouts.TLSHandshake) * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(config.Timeouts.ResponseHeader) * time.Second,
|
||||||
|
ExpectContinueTimeout: time.Duration(config.Timeouts.ExpectContinue) * time.Second,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: time.Duration(config.Timeouts.IdleConn) * time.Second,
|
||||||
|
}
|
||||||
|
reverseProxy.Transport = transport
|
||||||
|
|
||||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
@@ -218,6 +239,49 @@ func (p *Process) forceState(newState ProcessState) {
|
|||||||
// at any time.
|
// at any time.
|
||||||
func (p *Process) start() error {
|
func (p *Process) start() error {
|
||||||
|
|
||||||
|
// test-only fast path: skip subprocess, health check, and TTL goroutine
|
||||||
|
if p.testHandler != nil {
|
||||||
|
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||||
|
if err == ErrExpectedStateMismatch {
|
||||||
|
if curState == StateStarting {
|
||||||
|
p.waitStarting.Wait()
|
||||||
|
curState = p.CurrentState()
|
||||||
|
if curState == StateReady {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("process was already starting but wound up in state %v", curState)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
|
// Mimic the real stop path: cancelUpstream transitions
|
||||||
|
// StateStopping -> StateStopped and closes cmdWaitChan,
|
||||||
|
// matching what waitForCmd does for real subprocesses.
|
||||||
|
ch := make(chan struct{})
|
||||||
|
p.cmdMutex.Lock()
|
||||||
|
p.cancelUpstream = func() {
|
||||||
|
if curState := p.CurrentState(); curState == StateStopping {
|
||||||
|
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.forceState(StateStopped)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.forceState(StateStopped)
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
p.cmdWaitChan = ch
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
|
|
||||||
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
p.failedStartCount = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if p.config.Proxy == "" {
|
if p.config.Proxy == "" {
|
||||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
}
|
}
|
||||||
@@ -368,7 +432,10 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
// Stop will wait for inflight requests to complete before stopping the process.
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
|
|
||||||
|
// guard to prevent multiple goroutines from stopping
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,13 +448,17 @@ func (p *Process) Stop() {
|
|||||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
func (p *Process) StopImmediately() {
|
func (p *Process) StopImmediately() {
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
|
||||||
|
// guard to prevent multiple goroutines from stopping the process
|
||||||
|
enterState := p.CurrentState()
|
||||||
|
if !isValidTransition(enterState, StateStopping) {
|
||||||
|
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(enterState, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,6 +485,9 @@ func (p *Process) stopCommand() {
|
|||||||
stopStartTime := time.Now()
|
stopStartTime := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
|
|
||||||
|
// free the buffer in processLogger so the memory can be recovered
|
||||||
|
p.processLogger.Clear()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
p.cmdMutex.RLock()
|
p.cmdMutex.RLock()
|
||||||
@@ -556,6 +630,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
if !srw.waitForCompletion(completionTimeout) {
|
if !srw.waitForCompletion(completionTimeout) {
|
||||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.testHandler.ServeHTTP(w, r)
|
||||||
|
} else if srw != nil {
|
||||||
p.reverseProxy.ServeHTTP(srw, r)
|
p.reverseProxy.ServeHTTP(srw, r)
|
||||||
} else {
|
} else {
|
||||||
p.reverseProxy.ServeHTTP(w, r)
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
@@ -646,6 +725,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger returns the logger for this process.
|
||||||
|
func (p *Process) Logger() *LogMonitor {
|
||||||
|
return p.processLogger
|
||||||
|
}
|
||||||
|
|
||||||
var loadingRemarks = []string{
|
var loadingRemarks = []string{
|
||||||
"Still faster than your last standup meeting...",
|
"Still faster than your last standup meeting...",
|
||||||
"Reticulating splines...",
|
"Reticulating splines...",
|
||||||
@@ -864,7 +948,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
|||||||
s.Flush()
|
s.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Flush method
|
|
||||||
func (s *statusResponseWriter) Flush() {
|
func (s *statusResponseWriter) Flush() {
|
||||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -117,12 +118,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expectedMessage := "I_sense_imminent_danger"
|
expectedMessage := "I_sense_imminent_danger"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||||
assert.Equal(t, 0, config.UnloadAfter)
|
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||||
config.UnloadAfter = 3 // seconds
|
conf.UnloadAfter = 3 // seconds
|
||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, conf.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// this should take 4 seconds
|
// this should take 4 seconds
|
||||||
@@ -159,12 +160,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
|||||||
t.Skip("skipping test, edit process_test.go to run it ")
|
t.Skip("skipping test, edit process_test.go to run it ")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := getTestSimpleResponderConfig("fast_ttl")
|
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||||
assert.Equal(t, 0, config.UnloadAfter)
|
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||||
config.UnloadAfter = 1 // second
|
conf.UnloadAfter = 1 // second
|
||||||
assert.Equal(t, 1, config.UnloadAfter)
|
assert.Equal(t, 1, conf.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
@@ -395,6 +396,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
|||||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
// the upstream command
|
// the upstream command
|
||||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("skipping SIGTERM test on Windows ")
|
t.Skip("skipping SIGTERM test on Windows ")
|
||||||
}
|
}
|
||||||
@@ -565,3 +570,39 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
return w.ResponseRecorder.Write(b)
|
return w.ResponseRecorder.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_CustomTimeouts(t *testing.T) {
|
||||||
|
modelConfig := config.ModelConfig{
|
||||||
|
Cmd: "echo test",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
Timeouts: config.TimeoutsConfig{
|
||||||
|
Connect: 45,
|
||||||
|
ResponseHeader: 120,
|
||||||
|
TLSHandshake: 15,
|
||||||
|
ExpectContinue: 2,
|
||||||
|
IdleConn: 120,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
debugLogger := NewLogMonitorWriter(io.Discard)
|
||||||
|
process := NewProcess("test-model", 30, modelConfig, debugLogger, debugLogger)
|
||||||
|
|
||||||
|
// Verify the process was created successfully
|
||||||
|
assert.NotNil(t, process)
|
||||||
|
assert.Equal(t, "test-model", process.ID)
|
||||||
|
assert.NotNil(t, process.reverseProxy)
|
||||||
|
assert.NotNil(t, process.reverseProxy.Transport)
|
||||||
|
|
||||||
|
// Verify it's using http.Transport (not some other type)
|
||||||
|
transport, ok := process.reverseProxy.Transport.(*http.Transport)
|
||||||
|
assert.True(t, ok, "Transport should be *http.Transport")
|
||||||
|
assert.NotNil(t, transport)
|
||||||
|
|
||||||
|
// Verify the timeouts are correctly applied
|
||||||
|
assert.Equal(t, 120*time.Second, transport.ResponseHeaderTimeout)
|
||||||
|
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||||
|
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||||
|
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||||
|
assert.True(t, transport.ForceAttemptHTTP2)
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,6 +24,22 @@ type ProcessGroup struct {
|
|||||||
// map of current processes
|
// map of current processes
|
||||||
processes map[string]*Process
|
processes map[string]*Process
|
||||||
lastUsedProcess string
|
lastUsedProcess string
|
||||||
|
|
||||||
|
// inflight tracks fast-path requests (requests for the already-selected
|
||||||
|
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
|
||||||
|
// and Done() on completion; a concurrent swap request calls inflight.Wait()
|
||||||
|
// under pg.Lock before stopping the current process. Without this tracking,
|
||||||
|
// a fast-path request that has released pg.Lock but has not yet called
|
||||||
|
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
|
||||||
|
// killed mid-request.
|
||||||
|
inflight sync.WaitGroup
|
||||||
|
|
||||||
|
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
|
||||||
|
// the fast path after pg.Lock is released but before the request is
|
||||||
|
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
|
||||||
|
// request at the exact race window to deterministically reproduce the
|
||||||
|
// fast-path vs swap race.
|
||||||
|
testDelayFastPath func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
@@ -46,7 +62,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
|
|||||||
// Create a Process for each member in the group
|
// Create a Process for each member in the group
|
||||||
for _, modelID := range groupConfig.Members {
|
for _, modelID := range groupConfig.Members {
|
||||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||||
|
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||||
pg.processes[modelID] = process
|
pg.processes[modelID] = process
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
pg.Lock()
|
pg.Lock()
|
||||||
if pg.lastUsedProcess != modelID {
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// Wait for in-flight fast-path requests to drain before stopping
|
||||||
|
// the previous process. Without this, a fast-path request that has
|
||||||
|
// released pg.Lock but has not yet incremented
|
||||||
|
// Process.inFlightRequests races with Stop() and can be killed
|
||||||
|
// mid-request.
|
||||||
|
pg.inflight.Wait()
|
||||||
|
|
||||||
// is there something already running?
|
// is there something already running?
|
||||||
if pg.lastUsedProcess != "" {
|
if pg.lastUsedProcess != "" {
|
||||||
pg.processes[pg.lastUsedProcess].Stop()
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
@@ -77,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fast path: register this request in inflight before releasing
|
||||||
|
// pg.Lock so a concurrent swap will wait for it to complete.
|
||||||
|
pg.inflight.Add(1)
|
||||||
|
defer pg.inflight.Done()
|
||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
|
|
||||||
|
if pg.testDelayFastPath != nil {
|
||||||
|
pg.testDelayFastPath()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pg.processes[modelID].ProxyRequest(writer, request)
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
@@ -88,6 +121,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||||
|
if pg.HasMember(modelName) {
|
||||||
|
return pg.processes[modelName], true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
|
|
||||||
@@ -115,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
|||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
|
|
||||||
|
if strategy != StopImmediately {
|
||||||
|
pg.inflight.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
if len(pg.processes) == 0 {
|
if len(pg.processes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
@@ -49,6 +52,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
// and multiple requests are made in parallel, only one process is running at a time.
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]config.ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
@@ -91,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
|
||||||
|
// request cannot stop the current process while a fast-path request (for the
|
||||||
|
// already-selected model) is in flight. Without ProcessGroup-level inflight
|
||||||
|
// tracking, a fast-path request that has released pg.Lock but has not yet
|
||||||
|
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
|
||||||
|
// process is killed mid-request.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||||
|
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// Bypass real subprocesses so the test is fast and deterministic.
|
||||||
|
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: run a request through model1 via the swap path so that
|
||||||
|
// lastUsedProcess == "model1" and subsequent model1 requests take the
|
||||||
|
// fast path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||||
|
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
// Fast-path hook: signal arrival at the race window, then wait for
|
||||||
|
// release. This parks R2 deterministically at the point where pg.Lock
|
||||||
|
// has been released but Process.inFlightRequests has not yet been
|
||||||
|
// incremented — the exact window the race exploits.
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
pg.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
// R2: fast-path request for model1. Will pause at the test hook.
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Deterministically wait for R2 to reach the race window.
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// R3: swap request for model2. Must wait for R2 to finish before touching
|
||||||
|
// model1, otherwise model1 gets killed mid-request.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until R3 has acquired pg.Lock and entered the swap critical
|
||||||
|
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
|
||||||
|
// still holding the lock, so TryLock keeps failing.
|
||||||
|
for pg.TryLock() {
|
||||||
|
pg.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||||
|
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
|
||||||
|
// nothing changes, so we wait the full window. In the buggy code, R3
|
||||||
|
// will Stop() model1 and start serving via model2 within microseconds —
|
||||||
|
// we exit early once the mutation is observable.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if pg.processes["model1"].CurrentState() != StateReady ||
|
||||||
|
pg.processes["model2"].CurrentState() != StateStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
done := false
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
done = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while a fast-path request is in flight")
|
||||||
|
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
|
||||||
|
"model2 must not be started until R2 finishes and model1 is swapped out")
|
||||||
|
|
||||||
|
// Release R2 and let both requests finish.
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
assert.Equal(t, http.StatusOK, w3.Code)
|
||||||
|
assert.Contains(t, w3.Body.String(), "model2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
|
||||||
|
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
|
||||||
|
// process while a fast-path ProxyRequest is in the [pg.Unlock,
|
||||||
|
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
|
||||||
|
// StopProcesses, the external caller bypasses the inflight guard and kills the
|
||||||
|
// process mid-request.
|
||||||
|
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
|
||||||
|
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: model1 is active so subsequent model1 requests take the fast path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||||
|
|
||||||
|
// Park a fast-path request at the race window.
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
pg.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
|
||||||
|
// the group while a fast-path request is in flight.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until StopProcesses has acquired pg.Lock.
|
||||||
|
for pg.TryLock() {
|
||||||
|
pg.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
|
||||||
|
// and model1 stays Ready. In the buggy code it proceeds immediately and
|
||||||
|
// kills model1.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if pg.processes["model1"].CurrentState() != StateReady {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
goto done
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
done:
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while a fast-path request is in flight")
|
||||||
|
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -27,6 +28,40 @@ const (
|
|||||||
|
|
||||||
type proxyCtxKey string
|
type proxyCtxKey string
|
||||||
|
|
||||||
|
type InflightCounter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
total int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInflightCounter() *InflightCounter {
|
||||||
|
return &InflightCounter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ic *InflightCounter) Current() int {
|
||||||
|
ic.mu.Lock()
|
||||||
|
total := ic.total
|
||||||
|
ic.mu.Unlock()
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ic *InflightCounter) Increment() int {
|
||||||
|
ic.mu.Lock()
|
||||||
|
ic.total++
|
||||||
|
total := ic.total
|
||||||
|
ic.mu.Unlock()
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ic *InflightCounter) Decrement() int {
|
||||||
|
ic.mu.Lock()
|
||||||
|
if ic.total > 0 {
|
||||||
|
ic.total--
|
||||||
|
}
|
||||||
|
total := ic.total
|
||||||
|
ic.mu.Unlock()
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
@@ -42,6 +77,11 @@ type ProxyManager struct {
|
|||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// matrix-based swap (mutually exclusive with processGroups)
|
||||||
|
matrix *Matrix
|
||||||
|
|
||||||
|
inFlightCounter *InflightCounter
|
||||||
|
|
||||||
// shutdown signaling
|
// shutdown signaling
|
||||||
shutdownCtx context.Context
|
shutdownCtx context.Context
|
||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
@@ -50,19 +90,42 @@ type ProxyManager struct {
|
|||||||
buildDate string
|
buildDate string
|
||||||
commit string
|
commit string
|
||||||
version string
|
version string
|
||||||
|
|
||||||
|
// peer proxy see: #296, #433
|
||||||
|
peerProxy *PeerProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config config.Config) *ProxyManager {
|
func New(proxyConfig config.Config) *ProxyManager {
|
||||||
// set up loggers
|
// set up loggers
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
|
|
||||||
if config.LogRequests {
|
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
|
||||||
|
switch proxyConfig.LogToStdout {
|
||||||
|
case config.LogToStdoutNone:
|
||||||
|
muxLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
case config.LogToStdoutBoth:
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
case config.LogToStdoutUpstream:
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
default:
|
||||||
|
// same as config.LogToStdoutProxy
|
||||||
|
// helpful because some old tests create a config.Config directly and it
|
||||||
|
// may not have LogToStdout set explicitly
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxyConfig.LogRequests {
|
||||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
||||||
case "debug":
|
case "debug":
|
||||||
proxyLogger.SetLogLevel(LevelDebug)
|
proxyLogger.SetLogLevel(LevelDebug)
|
||||||
upstreamLogger.SetLogLevel(LevelDebug)
|
upstreamLogger.SetLogLevel(LevelDebug)
|
||||||
@@ -99,7 +162,7 @@ func New(config config.Config) *ProxyManager {
|
|||||||
"stampnano": time.StampNano,
|
"stampnano": time.StampNano,
|
||||||
}
|
}
|
||||||
|
|
||||||
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok {
|
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
||||||
proxyLogger.SetLogTimeFormat(timeFormat)
|
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||||
upstreamLogger.SetLogTimeFormat(timeFormat)
|
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||||
}
|
}
|
||||||
@@ -107,61 +170,93 @@ func New(config config.Config) *ProxyManager {
|
|||||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
var maxMetrics int
|
var maxMetrics int
|
||||||
if config.MetricsMaxInMemory <= 0 {
|
if proxyConfig.MetricsMaxInMemory <= 0 {
|
||||||
maxMetrics = 1000 // Default fallback
|
maxMetrics = 1000 // Default fallback
|
||||||
} else {
|
} else {
|
||||||
maxMetrics = config.MetricsMaxInMemory
|
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||||
|
}
|
||||||
|
|
||||||
|
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||||
|
if err != nil {
|
||||||
|
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||||
|
peerProxy = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: proxyConfig,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
|
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
muxLogger: stdoutLogger,
|
muxLogger: muxLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
|
inFlightCounter: newInflightCounter(),
|
||||||
|
|
||||||
shutdownCtx: shutdownCtx,
|
shutdownCtx: shutdownCtx,
|
||||||
shutdownCancel: shutdownCancel,
|
shutdownCancel: shutdownCancel,
|
||||||
|
|
||||||
buildDate: "unknown",
|
buildDate: "unknown",
|
||||||
commit: "abcd1234",
|
commit: "abcd1234",
|
||||||
version: "0",
|
version: "0",
|
||||||
|
|
||||||
|
peerProxy: peerProxy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the process groups
|
// create either matrix or process groups (mutually exclusive)
|
||||||
for groupID := range config.Groups {
|
if proxyConfig.Matrix != nil {
|
||||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
pm.matrix = NewMatrix(proxyConfig, proxyLogger, upstreamLogger)
|
||||||
pm.processGroups[groupID] = processGroup
|
} else {
|
||||||
|
for groupID := range proxyConfig.Groups {
|
||||||
|
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||||
|
pm.processGroups[groupID] = processGroup
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.setupGinEngine()
|
pm.setupGinEngine()
|
||||||
|
|
||||||
// run any startup hooks
|
// run any startup hooks
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
||||||
// do it in the background, don't block startup -- not sure if good idea yet
|
// do it in the background, don't block startup -- not sure if good idea yet
|
||||||
go func() {
|
go func() {
|
||||||
discardWriter := &DiscardWriter{}
|
discardWriter := &DiscardWriter{}
|
||||||
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
|
||||||
|
|
||||||
if err != nil {
|
if !ok {
|
||||||
|
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||||
|
|
||||||
|
var preloadErr error
|
||||||
|
req, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
preloadErr = pm.matrix.ProxyRequest(modelID, discardWriter, req)
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
preloadErr = err
|
||||||
|
} else {
|
||||||
|
preloadErr = processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if preloadErr != nil {
|
||||||
event.Emit(ModelPreloadedEvent{
|
event.Emit(ModelPreloadedEvent{
|
||||||
ModelName: realModelName,
|
ModelName: modelID,
|
||||||
Success: false,
|
Success: false,
|
||||||
})
|
})
|
||||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, preloadErr)
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
req, _ := http.NewRequest("GET", "/", nil)
|
|
||||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
|
||||||
event.Emit(ModelPreloadedEvent{
|
event.Emit(ModelPreloadedEvent{
|
||||||
ModelName: realModelName,
|
ModelName: modelID,
|
||||||
Success: true,
|
Success: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -236,37 +331,50 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
|
// Protected routes use pm.apiKeyAuth() middleware
|
||||||
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||||
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
// Support anthropic count_tokens API (Also added in the above PR)
|
||||||
|
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support embeddings and reranking
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /reranking endpoint + aliases
|
// llama-server's /reranking endpoint + aliases
|
||||||
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// llama-server's /completion endpoint
|
// llama-server's /completion endpoint
|
||||||
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||||
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||||
|
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
// sd.cpp /sdapi/v1 endpoints
|
||||||
|
pm.ginEngine.POST("/sdapi/v1/txt2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/sdapi/v1/img2img", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.GET("/sdapi/v1/loras", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyGETModelHandler)
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||||
|
|
||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User Interface Endpoints
|
* User Interface Endpoints
|
||||||
@@ -278,9 +386,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
c.Redirect(http.StatusFound, "/ui/models")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
})
|
})
|
||||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.trackInflight(), pm.proxyToUpstream)
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
})
|
})
|
||||||
@@ -302,25 +410,35 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||||
} else {
|
} else {
|
||||||
|
// Serve files with compression support under /ui/*
|
||||||
|
// This handler checks for pre-compressed .br and .gz files
|
||||||
|
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
||||||
|
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
||||||
|
// Default to index.html for directory-like paths
|
||||||
|
if filepath == "" {
|
||||||
|
filepath = "index.html"
|
||||||
|
}
|
||||||
|
|
||||||
// serve files that exist under /ui/*
|
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
||||||
pm.ginEngine.StaticFS("/ui", reactFS)
|
})
|
||||||
|
|
||||||
// server SPA for UI under /ui/*
|
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
||||||
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||||
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||||
c.AbortWithStatus(http.StatusNotFound)
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := reactFS.Open("index.html")
|
// Check if this looks like a file request (has extension)
|
||||||
if err != nil {
|
path := c.Request.URL.Path
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
||||||
|
// This was likely a file request that wasn't found
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer file.Close()
|
|
||||||
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
|
||||||
|
|
||||||
|
// Serve index.html for SPA routing
|
||||||
|
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,6 +450,14 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) trackInflight() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Increment()})
|
||||||
|
defer event.Emit(InFlightRequestsEvent{Total: pm.inFlightCounter.Decrement()})
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ServeHTTP implements http.Handler interface
|
// ServeHTTP implements http.Handler interface
|
||||||
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
pm.ginEngine.ServeHTTP(w, r)
|
pm.ginEngine.ServeHTTP(w, r)
|
||||||
@@ -345,6 +471,11 @@ func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
pm.matrix.StopProcesses(strategy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// stop Processes in parallel
|
// stop Processes in parallel
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, processGroup := range pm.processGroups {
|
for _, processGroup := range pm.processGroups {
|
||||||
@@ -365,6 +496,12 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
|
|
||||||
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||||
|
|
||||||
|
if pm.matrix != nil {
|
||||||
|
pm.matrix.Shutdown()
|
||||||
|
pm.shutdownCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
// Send shutdown signal to all process in groups
|
// Send shutdown signal to all process in groups
|
||||||
for _, processGroup := range pm.processGroups {
|
for _, processGroup := range pm.processGroups {
|
||||||
@@ -378,16 +515,10 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
pm.shutdownCancel()
|
pm.shutdownCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||||
// de-alias the real model name and get a real one
|
|
||||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if !found {
|
|
||||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
processGroup := pm.findGroupByModelName(realModelName)
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
if processGroup == nil {
|
if processGroup == nil {
|
||||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if processGroup.exclusive {
|
if processGroup.exclusive {
|
||||||
@@ -399,54 +530,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return processGroup, realModelName, nil
|
return processGroup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := make([]gin.H, 0, len(pm.config.Models))
|
data := make([]gin.H, 0, len(pm.config.Models))
|
||||||
createdTime := time.Now().Unix()
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
|
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||||
|
record := gin.H{
|
||||||
|
"id": modelId,
|
||||||
|
"object": "model",
|
||||||
|
"created": createdTime,
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
}
|
||||||
|
|
||||||
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||||
|
record["name"] = name
|
||||||
|
}
|
||||||
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||||
|
record["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata if present
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
record["meta"] = gin.H{
|
||||||
|
"llamaswap": modelConfig.Metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
if modelConfig.Unlisted {
|
if modelConfig.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
newRecord := func(modelId string) gin.H {
|
data = append(data, newRecord(id, modelConfig))
|
||||||
record := gin.H{
|
|
||||||
"id": modelId,
|
|
||||||
"object": "model",
|
|
||||||
"created": createdTime,
|
|
||||||
"owned_by": "llama-swap",
|
|
||||||
}
|
|
||||||
|
|
||||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
|
||||||
record["name"] = name
|
|
||||||
}
|
|
||||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
|
||||||
record["description"] = desc
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add metadata if present
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
record["meta"] = gin.H{
|
|
||||||
"llamaswap": modelConfig.Metadata,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return record
|
|
||||||
}
|
|
||||||
|
|
||||||
data = append(data, newRecord(id))
|
|
||||||
|
|
||||||
// Include aliases
|
// Include aliases
|
||||||
if pm.config.IncludeAliasesInList {
|
if pm.config.IncludeAliasesInList {
|
||||||
for _, alias := range modelConfig.Aliases {
|
for _, alias := range modelConfig.Aliases {
|
||||||
if alias := strings.TrimSpace(alias); alias != "" {
|
if alias := strings.TrimSpace(alias); alias != "" {
|
||||||
data = append(data, newRecord(alias))
|
data = append(data, newRecord(alias, modelConfig))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pm.peerProxy != nil {
|
||||||
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||||
|
// add peer models
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
// Skip unlisted models if not showing them
|
||||||
|
record := newRecord(modelID, config.ModelConfig{
|
||||||
|
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"peerID": peerID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
data = append(data, record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sort by the "id" key
|
// Sort by the "id" key
|
||||||
sort.Slice(data, func(i, j int) bool {
|
sort.Slice(data, func(i, j int) bool {
|
||||||
si, _ := data[i]["id"].(string)
|
si, _ := data[i]["id"].(string)
|
||||||
@@ -466,82 +614,87 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
// findModelInPath searches for a valid model name in a path with slashes.
|
||||||
upstreamPath := c.Param("upstreamPath")
|
// It iteratively builds up path segments until it finds a matching model.
|
||||||
|
// Returns: (searchModelName, realModelName, remainingPath, found)
|
||||||
// split the upstream path by / and search for the model name
|
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
||||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
||||||
if len(parts) == 0 {
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelFound := false
|
|
||||||
searchModelName := ""
|
searchModelName := ""
|
||||||
var modelName, remainingPath string
|
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
if parts[i] == "" {
|
if part == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if searchModelName == "" {
|
if searchModelName == "" {
|
||||||
searchModelName = part
|
searchModelName = part
|
||||||
} else {
|
} else {
|
||||||
searchModelName = searchModelName + "/" + parts[i]
|
searchModelName = searchModelName + "/" + part
|
||||||
}
|
}
|
||||||
|
|
||||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||||
modelName = real
|
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
|
||||||
modelFound = true
|
|
||||||
|
|
||||||
// Check if this is exactly a model name with no additional path
|
|
||||||
// and doesn't end with a trailing slash
|
|
||||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
|
||||||
// Build new URL with query parameters preserved
|
|
||||||
newPath := "/upstream/" + searchModelName + "/"
|
|
||||||
if c.Request.URL.RawQuery != "" {
|
|
||||||
newPath += "?" + c.Request.URL.RawQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use 308 for non-GET/HEAD requests to preserve method
|
|
||||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
|
||||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
|
||||||
} else {
|
|
||||||
c.Redirect(http.StatusPermanentRedirect, newPath)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
|
upstreamPath := c.Param("upstreamPath")
|
||||||
|
|
||||||
|
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||||
|
|
||||||
if !modelFound {
|
if !modelFound {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||||
if err != nil {
|
// This ensures relative URLs in upstream responses resolve correctly and
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||||
|
// HTTP method (301 would downgrade to GET).
|
||||||
|
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||||
|
newPath := "/upstream/" + searchModelName + "/"
|
||||||
|
if c.Request.URL.RawQuery != "" {
|
||||||
|
newPath += "?" + c.Request.URL.RawQuery
|
||||||
|
}
|
||||||
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||||
|
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||||
|
} else {
|
||||||
|
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var handler func(string, http.ResponseWriter, *http.Request) error
|
||||||
|
if pm.matrix != nil {
|
||||||
|
handler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
originalPath := c.Request.URL.Path
|
originalPath := c.Request.URL.Path
|
||||||
c.Request.URL.Path = remainingPath
|
c.Request.URL.Path = remainingPath
|
||||||
|
|
||||||
// attempt to record metrics if it is a POST request
|
// attempt to record metrics if it is a POST request
|
||||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, handler); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := handler(modelID, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -560,41 +713,107 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
// Look for a matching local model first
|
||||||
if !found {
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
if err != nil {
|
if found {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
var localHandler func(string, http.ResponseWriter, *http.Request) error
|
||||||
return
|
if pm.matrix != nil {
|
||||||
}
|
localHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
useModelName := pm.config.Models[realModelName].UseModelName
|
|
||||||
if useModelName != "" {
|
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue #174 strip parameters from the JSON body
|
|
||||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
|
||||||
if err != nil { // just log it and continue
|
|
||||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
|
||||||
} else {
|
|
||||||
for _, param := range stripParams {
|
|
||||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
|
||||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
localHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
|
useModelName := pm.config.Models[modelID].UseModelName
|
||||||
|
if useModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// issue #174 strip parameters from the JSON body
|
||||||
|
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||||
|
if err != nil { // just log it and continue
|
||||||
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||||
|
} else {
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #453 set/override parameters in the JSON body
|
||||||
|
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setParamsByID: set params based on the requested model ID (runs after setParams, can override it)
|
||||||
|
setParamsByIDParams, setParamsByIDKeys := pm.config.Models[modelID].Filters.SanitizedSetParamsByID(requestedModel)
|
||||||
|
for _, key := range setParamsByIDKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param by id: %s", requestedModel, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParamsByIDParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
nextHandler = localHandler
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
|
||||||
|
// issue #453 apply filters for peer requests
|
||||||
|
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||||
|
|
||||||
|
// Apply stripParams - remove specified parameters from request
|
||||||
|
stripParams := peerFilters.SanitizedStripParams()
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply setParams - set/override specified parameters in request
|
||||||
|
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
@@ -607,19 +826,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
|||||||
// issue #366 extract values that downstream handlers may need
|
// issue #366 extract values that downstream handlers may need
|
||||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
|
||||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -639,9 +858,33 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
// Look for a matching local model first, then check peers
|
||||||
if err != nil {
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
var useModelName string
|
||||||
|
|
||||||
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if found {
|
||||||
|
if pm.matrix != nil {
|
||||||
|
nextHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
useModelName = pm.config.Models[modelID].UseModelName
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -657,8 +900,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
// If this is the model field and we have a profile, use just the model name
|
// If this is the model field and we have a profile, use just the model name
|
||||||
if key == "model" {
|
if key == "model" {
|
||||||
// # issue #69 allow custom model names to be sent to upstream
|
// # issue #69 allow custom model names to be sent to upstream
|
||||||
useModelName := pm.config.Models[realModelName].UseModelName
|
|
||||||
|
|
||||||
if useModelName != "" {
|
if useModelName != "" {
|
||||||
fieldValue = useModelName
|
fieldValue = useModelName
|
||||||
} else {
|
} else {
|
||||||
@@ -728,9 +969,50 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||||
|
requestedModel := c.Query("model")
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
|
var modelID string
|
||||||
|
|
||||||
|
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||||
|
modelID = realModelID
|
||||||
|
if pm.matrix != nil {
|
||||||
|
nextHandler = pm.matrix.ProxyRequest
|
||||||
|
} else {
|
||||||
|
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
}
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
modelID = requestedModel
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -745,6 +1027,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||||
|
// Returns a pass-through handler if no API keys are configured.
|
||||||
|
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||||
|
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||||
|
return func(c *gin.Context) { c.Next() }
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
xApiKey := c.GetHeader("x-api-key")
|
||||||
|
|
||||||
|
var bearerKey string
|
||||||
|
var basicKey string
|
||||||
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||||
|
if strings.HasPrefix(auth, "Bearer ") {
|
||||||
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
} else if strings.HasPrefix(auth, "Basic ") {
|
||||||
|
// Basic Auth: base64(username:password), password is the API key
|
||||||
|
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||||
|
parts := strings.SplitN(string(decoded), ":", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
basicKey = parts[1] // password is the API key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use first key found: Basic, then Bearer, then x-api-key
|
||||||
|
var providedKey string
|
||||||
|
if basicKey != "" {
|
||||||
|
providedKey = basicKey
|
||||||
|
} else if bearerKey != "" {
|
||||||
|
providedKey = bearerKey
|
||||||
|
} else {
|
||||||
|
providedKey = xApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key
|
||||||
|
valid := false
|
||||||
|
for _, key := range pm.config.RequiredAPIKeys {
|
||||||
|
if providedKey == key {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip auth headers to prevent leakage to upstream
|
||||||
|
c.Request.Header.Del("Authorization")
|
||||||
|
c.Request.Header.Del("x-api-key")
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses(StopImmediately)
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
@@ -754,15 +1097,36 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.Header("Content-Type", "application/json")
|
context.Header("Content-Type", "application/json")
|
||||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||||
|
|
||||||
for _, processGroup := range pm.processGroups {
|
if pm.matrix != nil {
|
||||||
for _, process := range processGroup.processes {
|
for _, modelID := range pm.matrix.RunningModels() {
|
||||||
if process.CurrentState() == StateReady {
|
if process, ok := pm.matrix.GetProcess(modelID); ok {
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
"model": process.ID,
|
"model": process.ID,
|
||||||
"state": process.state,
|
"state": process.state,
|
||||||
|
"cmd": process.config.Cmd,
|
||||||
|
"proxy": process.config.Proxy,
|
||||||
|
"ttl": process.config.UnloadAfter,
|
||||||
|
"name": process.config.Name,
|
||||||
|
"description": process.config.Description,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for _, processGroup := range pm.processGroups {
|
||||||
|
for _, process := range processGroup.processes {
|
||||||
|
if process.CurrentState() == StateReady {
|
||||||
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
|
"model": process.ID,
|
||||||
|
"state": process.state,
|
||||||
|
"cmd": process.config.Cmd,
|
||||||
|
"proxy": process.config.Proxy,
|
||||||
|
"ttl": process.config.UnloadAfter,
|
||||||
|
"name": process.config.Name,
|
||||||
|
"description": process.config.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put the results under the `running` key.
|
// Put the results under the `running` key.
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,22 +14,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
State string `json:"state"`
|
State string `json:"state"`
|
||||||
Unlisted bool `json:"unlisted"`
|
Unlisted bool `json:"unlisted"`
|
||||||
|
PeerID string `json:"peerID"`
|
||||||
|
Aliases []string `json:"aliases,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func addApiHandlers(pm *ProxyManager) {
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
// Add API endpoints for React to consume
|
// Add API endpoints for React to consume
|
||||||
apiGroup := pm.ginEngine.Group("/api")
|
// Protected with API key authentication
|
||||||
|
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||||
{
|
{
|
||||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||||
apiGroup.GET("/events", pm.apiSendEvents)
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||||
apiGroup.GET("/version", pm.apiGetVersion)
|
apiGroup.GET("/version", pm.apiGetVersion)
|
||||||
|
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,27 +55,28 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
|||||||
// Iterate over sorted keys
|
// Iterate over sorted keys
|
||||||
for _, modelID := range modelIDs {
|
for _, modelID := range modelIDs {
|
||||||
// Get process state
|
// Get process state
|
||||||
processGroup := pm.findGroupByModelName(modelID)
|
|
||||||
state := "unknown"
|
state := "unknown"
|
||||||
if processGroup != nil {
|
var process *Process
|
||||||
process := processGroup.processes[modelID]
|
if pm.matrix != nil {
|
||||||
if process != nil {
|
process, _ = pm.matrix.GetProcess(modelID)
|
||||||
var stateStr string
|
} else {
|
||||||
switch process.CurrentState() {
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
case StateReady:
|
if processGroup != nil {
|
||||||
stateStr = "ready"
|
process = processGroup.processes[modelID]
|
||||||
case StateStarting:
|
}
|
||||||
stateStr = "starting"
|
}
|
||||||
case StateStopping:
|
if process != nil {
|
||||||
stateStr = "stopping"
|
switch process.CurrentState() {
|
||||||
case StateShutdown:
|
case StateReady:
|
||||||
stateStr = "shutdown"
|
state = "ready"
|
||||||
case StateStopped:
|
case StateStarting:
|
||||||
stateStr = "stopped"
|
state = "starting"
|
||||||
default:
|
case StateStopping:
|
||||||
stateStr = "unknown"
|
state = "stopping"
|
||||||
}
|
case StateShutdown:
|
||||||
state = stateStr
|
state = "shutdown"
|
||||||
|
case StateStopped:
|
||||||
|
state = "stopped"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
models = append(models, Model{
|
models = append(models, Model{
|
||||||
@@ -79,9 +85,22 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
|||||||
Description: pm.config.Models[modelID].Description,
|
Description: pm.config.Models[modelID].Description,
|
||||||
State: state,
|
State: state,
|
||||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||||
|
Aliases: pm.config.Models[modelID].Aliases,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Iterate over the peer models
|
||||||
|
if pm.peerProxy != nil {
|
||||||
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
models = append(models, Model{
|
||||||
|
Id: modelID,
|
||||||
|
PeerID: peerID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +110,7 @@ const (
|
|||||||
msgTypeModelStatus messageType = "modelStatus"
|
msgTypeModelStatus messageType = "modelStatus"
|
||||||
msgTypeLogData messageType = "logData"
|
msgTypeLogData messageType = "logData"
|
||||||
msgTypeMetrics messageType = "metrics"
|
msgTypeMetrics messageType = "metrics"
|
||||||
|
msgTypeInFlight messageType = "inflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
type messageEnvelope struct {
|
type messageEnvelope struct {
|
||||||
@@ -150,6 +170,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sendInFlight := func(total int) {
|
||||||
|
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Send updated models list
|
* Send updated models list
|
||||||
*/
|
*/
|
||||||
@@ -177,11 +209,19 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
sendMetrics([]TokenMetrics{e.Metrics})
|
sendMetrics([]TokenMetrics{e.Metrics})
|
||||||
})()
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||||
|
*/
|
||||||
|
defer event.On(func(e InFlightRequestsEvent) {
|
||||||
|
sendInFlight(e.Total)
|
||||||
|
})()
|
||||||
|
|
||||||
// send initial batch of data
|
// send initial batch of data
|
||||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
sendModels()
|
sendModels()
|
||||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||||
|
sendInFlight(pm.inFlightCounter.Current())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -215,18 +255,23 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup := pm.findGroupByModelName(realModelName)
|
var stopErr error
|
||||||
if processGroup == nil {
|
if pm.matrix != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
stopErr = pm.matrix.StopProcess(realModelName, StopImmediately)
|
||||||
return
|
} else {
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopErr = processGroup.StopProcess(realModelName, StopImmediately)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
if stopErr != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", stopErr.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
c.String(http.StatusOK, "OK")
|
|
||||||
}
|
}
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||||
@@ -236,3 +281,35 @@ func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
|||||||
"build_date": pm.buildDate,
|
"build_date": pm.buildDate,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, exists := pm.metricsMonitor.getCompressedBytes(id)
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Vary", "Accept-Encoding")
|
||||||
|
|
||||||
|
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
|
||||||
|
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
|
||||||
|
|
||||||
|
if hasZstd {
|
||||||
|
c.Header("Content-Encoding", "zstd")
|
||||||
|
c.Data(http.StatusOK, "application/json", data)
|
||||||
|
} else {
|
||||||
|
decompressed, err := decompressCapture(data)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", decompressed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
// prevent nginx from buffering streamed logs
|
// prevent nginx from buffering streamed logs
|
||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
|
|
||||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||||
var logger *LogMonitor
|
switch logMonitorId {
|
||||||
|
case "":
|
||||||
if logMonitorId == "" {
|
|
||||||
// maintain the default
|
// maintain the default
|
||||||
logger = pm.muxLogger
|
return pm.muxLogger, nil
|
||||||
} else if logMonitorId == "proxy" {
|
case "proxy":
|
||||||
logger = pm.proxyLogger
|
return pm.proxyLogger, nil
|
||||||
} else if logMonitorId == "upstream" {
|
case "upstream":
|
||||||
logger = pm.upstreamLogger
|
return pm.upstreamLogger, nil
|
||||||
} else {
|
default:
|
||||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
// search for a models specific logger using findModelInPath
|
||||||
}
|
// to handle model names with slashes (e.g., "author/model")
|
||||||
|
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||||
|
for _, group := range pm.processGroups {
|
||||||
|
if process, found := group.GetMember(name); found {
|
||||||
|
return process.Logger(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return logger, nil
|
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||||
|
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||||
|
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||||
|
if acceptEncoding == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||||
|
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||||
|
if enc == "br" {
|
||||||
|
return "br", ".br"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||||
|
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||||
|
if enc == "gzip" {
|
||||||
|
return "gzip", ".gz"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeCompressedFile serves a file with compression support.
|
||||||
|
// It checks for pre-compressed versions and serves them with proper headers.
|
||||||
|
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||||
|
|
||||||
|
// Try to serve compressed version if client supports it
|
||||||
|
if encoding != "" {
|
||||||
|
if cf, err := fs.Open(name + ext); err == nil {
|
||||||
|
defer cf.Close()
|
||||||
|
|
||||||
|
// Verify it's a regular file (not a directory)
|
||||||
|
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||||
|
// Set the content encoding header
|
||||||
|
w.Header().Set("Content-Encoding", encoding)
|
||||||
|
w.Header().Add("Vary", "Accept-Encoding")
|
||||||
|
|
||||||
|
// Get original file info for content type detection
|
||||||
|
origFile, err := fs.Open(name)
|
||||||
|
if err == nil {
|
||||||
|
origFile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve the compressed file
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to serving the uncompressed file
|
||||||
|
file, err := fs.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.IsDir() {
|
||||||
|
http.Error(w, "is a directory", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||||
|
}
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content that should be compressed with brotli")
|
||||||
|
brContent := []byte("fake-brotli-compressed-data")
|
||||||
|
|
||||||
|
// Create a test filesystem
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that brotli is used (preferred over gzip)
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||||
|
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, brContent) {
|
||||||
|
t.Errorf("Expected brotli content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content that should be compressed with gzip")
|
||||||
|
gzContent := []byte("fake-gzip-compressed-data")
|
||||||
|
|
||||||
|
// Create a test filesystem without brotli
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, gzContent) {
|
||||||
|
t.Errorf("Expected gzip content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is uncompressed test content")
|
||||||
|
|
||||||
|
// Create a test filesystem without compressed versions
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not have Content-Encoding header since we're serving uncompressed
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, content) {
|
||||||
|
t.Errorf("Expected original content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content")
|
||||||
|
|
||||||
|
// Create a test filesystem with compressed versions
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
// No Accept-Encoding header
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should serve uncompressed content
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, content) {
|
||||||
|
t.Errorf("Expected original content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||||
|
mapFS := fstest.MapFS{}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectEncoding(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
acceptEncoding string
|
||||||
|
wantEncoding string
|
||||||
|
wantExt string
|
||||||
|
}{
|
||||||
|
{"br, gzip", "br", ".br"},
|
||||||
|
{"gzip, deflate", "gzip", ".gz"},
|
||||||
|
{"gzip", "gzip", ".gz"},
|
||||||
|
{"br", "br", ".br"},
|
||||||
|
{"", "", ""},
|
||||||
|
{"deflate", "", ""},
|
||||||
|
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||||
|
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||||
|
{"browser", "", ""},
|
||||||
|
{"compress, deflate", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||||
|
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||||
|
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||||
|
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with actual pre-compressed files from ui_dist
|
||||||
|
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||||
|
// Check if ui_dist exists
|
||||||
|
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||||
|
t.Skip("ui_dist not found, skipping real file test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a .js or .css file that has compressed versions
|
||||||
|
entries, err := os.ReadDir("./ui_dist/assets")
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var testFile string
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := entry.Name()
|
||||||
|
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||||
|
// Check if compressed versions exist
|
||||||
|
base := strings.TrimSuffix(name, ".js")
|
||||||
|
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||||
|
testFile = "assets/" + name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if testFile == "" {
|
||||||
|
t.Skip("No suitable test file found with compressed versions")
|
||||||
|
}
|
||||||
|
|
||||||
|
fs := http.FS(os.DirFS("./ui_dist"))
|
||||||
|
|
||||||
|
// Test brotli
|
||||||
|
t.Run("brotli", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, testFile)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test gzip
|
||||||
|
t.Run("gzip", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, testFile)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid gzip
|
||||||
|
reader, err := gzip.NewReader(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected valid gzip content: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Just read to verify it's valid
|
||||||
|
_, err = io.Copy(io.Discard, reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decompress gzip: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
node_modules
|
||||||
|
.vite
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
legacy-peer-deps=true
|
||||||
@@ -10,8 +10,8 @@
|
|||||||
<link rel="manifest" href="/site.webmanifest" />
|
<link rel="manifest" href="/site.webmanifest" />
|
||||||
<title>llama-swap</title>
|
<title>llama-swap</title>
|
||||||
</head>
|
</head>
|
||||||
<body >
|
<body>
|
||||||
<div id="root"></div>
|
<div id="app"></div>
|
||||||
<script type="module" src="/src/main.tsx"></script>
|
<script type="module" src="/src/main.ts"></script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
{
|
||||||
|
"name": "ui-svelte",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"start": "vite",
|
||||||
|
"build": "vite build --emptyOutDir",
|
||||||
|
"preview": "vite preview",
|
||||||
|
"check": "svelte-check --tsconfig ./tsconfig.json",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@sveltejs/vite-plugin-svelte": "^7.0.0",
|
||||||
|
"@tailwindcss/vite": "^4.1.8",
|
||||||
|
"@tsconfig/svelte": "^5.0.4",
|
||||||
|
"@types/hast": "^3.0.4",
|
||||||
|
"@types/node": "^25.1.0",
|
||||||
|
"svelte": "^5.46.4",
|
||||||
|
"svelte-check": "^4.1.4",
|
||||||
|
"tailwindcss": "^4.1.8",
|
||||||
|
"typescript": "~5.8.3",
|
||||||
|
"vite": "^8.0.0",
|
||||||
|
"vite-plugin-compression2": "^2.5.1",
|
||||||
|
"vitest": "^4.1.0"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"highlight.js": "^11.11.1",
|
||||||
|
"katex": "^0.16.28",
|
||||||
|
"lucide-svelte": "^0.563.0",
|
||||||
|
"rehype-katex": "^7.0.1",
|
||||||
|
"rehype-stringify": "^10.0.1",
|
||||||
|
"remark-gfm": "^4.0.1",
|
||||||
|
"remark-math": "^6.0.0",
|
||||||
|
"remark-parse": "^11.0.0",
|
||||||
|
"remark-rehype": "^11.1.2",
|
||||||
|
"svelte-spa-router": "^4.0.1",
|
||||||
|
"unified": "^11.0.5",
|
||||||
|
"unist-util-visit": "^5.1.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Before Width: | Height: | Size: 5.9 KiB After Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 38 KiB After Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 6.5 KiB After Width: | Height: | Size: 6.5 KiB |
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,58 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { onMount } from "svelte";
|
||||||
|
import Router from "svelte-spa-router";
|
||||||
|
import Header from "./components/Header.svelte";
|
||||||
|
import LogViewer from "./routes/LogViewer.svelte";
|
||||||
|
import Models from "./routes/Models.svelte";
|
||||||
|
import Activity from "./routes/Activity.svelte";
|
||||||
|
import Playground from "./routes/Playground.svelte";
|
||||||
|
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||||
|
import { enableAPIEvents } from "./stores/api";
|
||||||
|
import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||||
|
import { currentRoute } from "./stores/route";
|
||||||
|
|
||||||
|
const routes = {
|
||||||
|
"/": PlaygroundStub,
|
||||||
|
"/models": Models,
|
||||||
|
"/logs": LogViewer,
|
||||||
|
"/activity": Activity,
|
||||||
|
"*": PlaygroundStub,
|
||||||
|
};
|
||||||
|
|
||||||
|
function handleRouteLoaded(event: { detail: { route: string | RegExp } }) {
|
||||||
|
const route = event.detail.route;
|
||||||
|
currentRoute.set(typeof route === "string" ? route : "/");
|
||||||
|
}
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
document.documentElement.setAttribute("data-theme", $isDarkMode ? "dark" : "light");
|
||||||
|
});
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
const icon = $connectionState === "connecting" ? "\u{1F7E1}" : $connectionState === "connected" ? "\u{1F7E2}" : "\u{1F534}";
|
||||||
|
document.title = `${icon} ${$appTitle}`;
|
||||||
|
});
|
||||||
|
|
||||||
|
onMount(() => {
|
||||||
|
const cleanupScreenWidth = initScreenWidth();
|
||||||
|
enableAPIEvents(true);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
cleanupScreenWidth();
|
||||||
|
enableAPIEvents(false);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="flex flex-col h-screen">
|
||||||
|
<Header />
|
||||||
|
|
||||||
|
<main class="flex-1 overflow-auto p-4">
|
||||||
|
<div class="h-full" class:hidden={$currentRoute !== "/"}>
|
||||||
|
<Playground />
|
||||||
|
</div>
|
||||||
|
<div class="h-full" class:hidden={$currentRoute === "/"}>
|
||||||
|
<Router {routes} on:routeLoaded={handleRouteLoaded} />
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
</div>
|
||||||
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 12 KiB |
|
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,97 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { inFlightRequests, metrics } from "../stores/api";
|
||||||
|
import { persistentStore } from "../stores/persistent";
|
||||||
|
import { calculateHistogramData } from "../lib/histogram";
|
||||||
|
import TokenHistogram from "./TokenHistogram.svelte";
|
||||||
|
|
||||||
|
const nf = new Intl.NumberFormat();
|
||||||
|
const histogramCollapsed = persistentStore<boolean>("activity-histogram-collapsed", false);
|
||||||
|
|
||||||
|
let stats = $derived.by(() => {
|
||||||
|
const totalRequests = $metrics.length;
|
||||||
|
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||||
|
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||||
|
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0);
|
||||||
|
|
||||||
|
const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second);
|
||||||
|
|
||||||
|
const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second);
|
||||||
|
|
||||||
|
const promptHistogramData =
|
||||||
|
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
|
||||||
|
|
||||||
|
const genHistogramData =
|
||||||
|
tokensPerSecond.length > 0 ? calculateHistogramData(tokensPerSecond) : null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
totalRequests,
|
||||||
|
totalInputTokens,
|
||||||
|
totalOutputTokens,
|
||||||
|
totalCacheTokens,
|
||||||
|
inFlightRequests: $inFlightRequests,
|
||||||
|
promptHistogramData,
|
||||||
|
genHistogramData,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="card relative p-3">
|
||||||
|
<button
|
||||||
|
class="absolute top-2 right-2 w-6 h-6 flex items-center justify-center rounded-full border border-gray-300 dark:border-gray-600 text-gray-400 dark:text-gray-500 hover:text-gray-600 dark:hover:text-gray-300 hover:border-gray-400 dark:hover:border-gray-400 transition-colors"
|
||||||
|
onclick={() => ($histogramCollapsed = !$histogramCollapsed)}
|
||||||
|
title={$histogramCollapsed ? "Show histograms" : "Hide histograms"}
|
||||||
|
>
|
||||||
|
{#if $histogramCollapsed}
|
||||||
|
<svg class="w-3.5 h-3.5" viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M4.5 6l3.5 4 3.5-4H4.5z" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg class="w-3 h-3" viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M3.5 3.5l9 9M12.5 3.5l-9 9" stroke="currentColor" stroke-width="2" stroke-linecap="round" fill="none" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
{#if !$histogramCollapsed}
|
||||||
|
<div class="flex flex-col sm:flex-row gap-6 mb-3">
|
||||||
|
<div class="w-full sm:w-1/2 min-w-0">
|
||||||
|
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Prompt Processing</div>
|
||||||
|
{#if stats.promptHistogramData}
|
||||||
|
<TokenHistogram
|
||||||
|
data={stats.promptHistogramData}
|
||||||
|
unit="prompt tokens/sec"
|
||||||
|
colorClass="text-amber-500 dark:text-amber-400"
|
||||||
|
/>
|
||||||
|
{:else}
|
||||||
|
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No prompt speed data yet</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
<div class="w-full sm:w-1/2 min-w-0">
|
||||||
|
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Token Generation</div>
|
||||||
|
{#if stats.genHistogramData}
|
||||||
|
<TokenHistogram data={stats.genHistogramData} unit="tokens/sec" />
|
||||||
|
{:else}
|
||||||
|
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No generation speed data yet</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
<div class="grid grid-cols-4 gap-x-6 gap-y-1 text-sm">
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Requests</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Cached</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Processed</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Generated</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalRequests)}</span> completed,
|
||||||
|
<span class="font-semibold">{nf.format(stats.inFlightRequests)}</span> waiting
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalCacheTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalInputTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalOutputTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,453 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import type { ReqRespCapture } from "../lib/types";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
capture: ReqRespCapture | null;
|
||||||
|
open: boolean;
|
||||||
|
onclose: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
let { capture, open, onclose }: Props = $props();
|
||||||
|
|
||||||
|
let dialogEl: HTMLDialogElement | undefined = $state();
|
||||||
|
|
||||||
|
type BodyTab = "raw" | "pretty" | "chat";
|
||||||
|
let reqBodyTab: BodyTab = $state("pretty");
|
||||||
|
let respBodyTab: BodyTab = $state("pretty");
|
||||||
|
let copiedReq = $state(false);
|
||||||
|
let copiedResp = $state(false);
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (open && dialogEl) {
|
||||||
|
dialogEl.showModal();
|
||||||
|
} else if (!open && dialogEl) {
|
||||||
|
dialogEl.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Reset tabs when capture changes
|
||||||
|
$effect(() => {
|
||||||
|
if (capture) {
|
||||||
|
const reqCt = getContentType(capture.req_headers);
|
||||||
|
const respCt = getContentType(capture.resp_headers);
|
||||||
|
reqBodyTab = reqCt.includes("json") ? "pretty" : "raw";
|
||||||
|
respBodyTab = respCt.includes("text/event-stream")
|
||||||
|
? "chat"
|
||||||
|
: respCt.includes("json")
|
||||||
|
? "pretty"
|
||||||
|
: "raw";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
function handleDialogClose() {
|
||||||
|
onclose();
|
||||||
|
}
|
||||||
|
|
||||||
|
function decodeBody(body: string | null | undefined): string {
|
||||||
|
if (!body) return "";
|
||||||
|
try {
|
||||||
|
const binary = atob(body);
|
||||||
|
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
|
||||||
|
return new TextDecoder().decode(bytes);
|
||||||
|
} catch {
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatJson(str: string): string {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(str);
|
||||||
|
return JSON.stringify(parsed, null, 2);
|
||||||
|
} catch {
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getContentType(
|
||||||
|
headers: Record<string, string> | null | undefined,
|
||||||
|
): string {
|
||||||
|
if (!headers) return "";
|
||||||
|
const ct = headers["Content-Type"] || headers["content-type"] || "";
|
||||||
|
return ct.toLowerCase();
|
||||||
|
}
|
||||||
|
|
||||||
|
function isImageContentType(contentType: string): boolean {
|
||||||
|
return contentType.startsWith("image/");
|
||||||
|
}
|
||||||
|
|
||||||
|
function isTextContentType(contentType: string): boolean {
|
||||||
|
return (
|
||||||
|
contentType.startsWith("text/") ||
|
||||||
|
contentType.includes("application/json") ||
|
||||||
|
contentType.includes("application/xml") ||
|
||||||
|
contentType.includes("application/javascript")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function getImageDataUrl(body: string, contentType: string): string {
|
||||||
|
const mimeType = contentType.split(";")[0].trim();
|
||||||
|
return `data:${mimeType};base64,${body}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SSEChat {
|
||||||
|
reasoning: string;
|
||||||
|
content: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseSSEChat(text: string): SSEChat {
|
||||||
|
const result: SSEChat = { reasoning: "", content: "" };
|
||||||
|
for (const line of text.split("\n")) {
|
||||||
|
const trimmed = line.trim();
|
||||||
|
if (!trimmed || !trimmed.startsWith("data: ")) continue;
|
||||||
|
const data = trimmed.slice(6);
|
||||||
|
if (data === "[DONE]") continue;
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(data);
|
||||||
|
const delta = parsed.choices?.[0]?.delta;
|
||||||
|
if (delta?.content) result.content += delta.content;
|
||||||
|
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
||||||
|
if (delta?.reasoning) result.reasoning += delta.reasoning;
|
||||||
|
} catch {
|
||||||
|
// skip unparseable lines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function copyToClipboard(text: string, type: "req" | "resp") {
|
||||||
|
try {
|
||||||
|
await navigator.clipboard.writeText(text);
|
||||||
|
if (type === "req") {
|
||||||
|
copiedReq = true;
|
||||||
|
setTimeout(() => (copiedReq = false), 1500);
|
||||||
|
} else {
|
||||||
|
copiedResp = true;
|
||||||
|
setTimeout(() => (copiedResp = false), 1500);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getCopyText(): string {
|
||||||
|
if (respBodyTab === "chat") {
|
||||||
|
let text = "";
|
||||||
|
if (sseChat.reasoning) text += sseChat.reasoning + "\n\n";
|
||||||
|
text += sseChat.content;
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
return displayedResponseBody;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request body derivations
|
||||||
|
let requestContentType = $derived(
|
||||||
|
capture ? getContentType(capture.req_headers) : "",
|
||||||
|
);
|
||||||
|
let isRequestJson = $derived(requestContentType.includes("json"));
|
||||||
|
|
||||||
|
let requestBodyRaw = $derived.by(() => {
|
||||||
|
if (!capture) return "";
|
||||||
|
return decodeBody(capture.req_body);
|
||||||
|
});
|
||||||
|
|
||||||
|
let requestBodyPretty = $derived.by(() => {
|
||||||
|
if (!isRequestJson) return requestBodyRaw;
|
||||||
|
return formatJson(requestBodyRaw);
|
||||||
|
});
|
||||||
|
|
||||||
|
let displayedRequestBody = $derived(
|
||||||
|
reqBodyTab === "pretty" ? requestBodyPretty : requestBodyRaw,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Response body derivations
|
||||||
|
let responseContentType = $derived(
|
||||||
|
capture ? getContentType(capture.resp_headers) : "",
|
||||||
|
);
|
||||||
|
let isResponseImage = $derived(isImageContentType(responseContentType));
|
||||||
|
let isResponseText = $derived(isTextContentType(responseContentType));
|
||||||
|
let isResponseJson = $derived(responseContentType.includes("json"));
|
||||||
|
let isSSE = $derived(responseContentType.includes("text/event-stream"));
|
||||||
|
|
||||||
|
let responseBodyRaw = $derived.by(() => {
|
||||||
|
if (!capture) return "";
|
||||||
|
return decodeBody(capture.resp_body);
|
||||||
|
});
|
||||||
|
|
||||||
|
let responseBodyPretty = $derived.by(() => {
|
||||||
|
if (!isResponseJson) return responseBodyRaw;
|
||||||
|
return formatJson(responseBodyRaw);
|
||||||
|
});
|
||||||
|
|
||||||
|
let sseChat = $derived.by(() => {
|
||||||
|
if (!isSSE || !responseBodyRaw)
|
||||||
|
return { reasoning: "", content: "" } as SSEChat;
|
||||||
|
return parseSSEChat(responseBodyRaw);
|
||||||
|
});
|
||||||
|
|
||||||
|
let displayedResponseBody = $derived.by(() => {
|
||||||
|
if (respBodyTab === "pretty") return responseBodyPretty;
|
||||||
|
return responseBodyRaw;
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<dialog
|
||||||
|
bind:this={dialogEl}
|
||||||
|
onclose={handleDialogClose}
|
||||||
|
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
|
||||||
|
>
|
||||||
|
{#if capture}
|
||||||
|
<div class="flex flex-col max-h-[90vh]">
|
||||||
|
<div
|
||||||
|
class="flex justify-between items-center p-4 border-b border-card-border"
|
||||||
|
>
|
||||||
|
<h2 class="text-xl font-bold pb-0">Capture #{capture.id + 1}{#if capture.req_path} <span class="text-base font-mono font-normal text-txtsecondary">{capture.req_path}</span>{/if}</h2>
|
||||||
|
<button
|
||||||
|
onclick={() => dialogEl?.close()}
|
||||||
|
class="text-txtsecondary hover:text-txtmain text-2xl leading-none"
|
||||||
|
>
|
||||||
|
×
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="overflow-y-auto flex-1 p-4 space-y-4">
|
||||||
|
<!-- Request Headers -->
|
||||||
|
<details class="group" open>
|
||||||
|
<summary
|
||||||
|
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||||
|
>
|
||||||
|
Request Headers
|
||||||
|
</summary>
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||||
|
>
|
||||||
|
<table class="w-full text-sm">
|
||||||
|
<tbody>
|
||||||
|
{#each Object.entries(capture.req_headers || {}) as [key, value]}
|
||||||
|
<tr class="border-b border-card-border-inner last:border-0">
|
||||||
|
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||||
|
>{key}</td
|
||||||
|
>
|
||||||
|
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||||
|
</tr>
|
||||||
|
{/each}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<!-- Request Body -->
|
||||||
|
<details class="group" open>
|
||||||
|
<summary
|
||||||
|
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||||
|
>
|
||||||
|
Request Body
|
||||||
|
</summary>
|
||||||
|
{#if requestBodyRaw}
|
||||||
|
<div class="mt-2 flex items-center justify-between">
|
||||||
|
<div class="flex gap-1">
|
||||||
|
{#if isRequestJson}
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
class:tab-btn-active={reqBodyTab === "pretty"}
|
||||||
|
onclick={() => (reqBodyTab = "pretty")}>Pretty</button
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
class:tab-btn-active={reqBodyTab === "raw"}
|
||||||
|
onclick={() => (reqBodyTab = "raw")}>Raw</button
|
||||||
|
>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
onclick={() =>
|
||||||
|
copyToClipboard(displayedRequestBody, "req")}
|
||||||
|
>
|
||||||
|
{#if copiedReq}
|
||||||
|
Copied!
|
||||||
|
{:else}
|
||||||
|
Copy
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
<pre
|
||||||
|
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedRequestBody}</pre>
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
<pre class="p-3 text-sm font-mono whitespace-pre-wrap break-all"
|
||||||
|
>(empty)</pre
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<!-- Response Headers -->
|
||||||
|
<details class="group" open>
|
||||||
|
<summary
|
||||||
|
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||||
|
>
|
||||||
|
Response Headers
|
||||||
|
</summary>
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-48"
|
||||||
|
>
|
||||||
|
<table class="w-full text-sm">
|
||||||
|
<tbody>
|
||||||
|
{#each Object.entries(capture.resp_headers || {}) as [key, value]}
|
||||||
|
<tr class="border-b border-card-border-inner last:border-0">
|
||||||
|
<td class="px-3 py-1 font-mono text-primary whitespace-nowrap"
|
||||||
|
>{key}</td
|
||||||
|
>
|
||||||
|
<td class="px-3 py-1 font-mono break-all">{value}</td>
|
||||||
|
</tr>
|
||||||
|
{/each}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<!-- Response Body -->
|
||||||
|
<details class="group" open>
|
||||||
|
<summary
|
||||||
|
class="cursor-pointer font-semibold text-sm uppercase tracking-wider text-txtsecondary hover:text-txtmain"
|
||||||
|
>
|
||||||
|
Response Body
|
||||||
|
</summary>
|
||||||
|
{#if isResponseImage && capture.resp_body}
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
<div class="p-3 flex justify-center">
|
||||||
|
<img
|
||||||
|
src={getImageDataUrl(capture.resp_body, responseContentType)}
|
||||||
|
alt="Response"
|
||||||
|
class="max-w-full h-auto"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else if isSSE || isResponseText}
|
||||||
|
<div class="mt-2 flex items-center justify-between">
|
||||||
|
<div class="flex gap-1">
|
||||||
|
{#if isSSE}
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
class:tab-btn-active={respBodyTab === "chat"}
|
||||||
|
onclick={() => (respBodyTab = "chat")}>Chat</button
|
||||||
|
>
|
||||||
|
{/if}
|
||||||
|
{#if isResponseJson}
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
class:tab-btn-active={respBodyTab === "pretty"}
|
||||||
|
onclick={() => (respBodyTab = "pretty")}>Pretty</button
|
||||||
|
>
|
||||||
|
{/if}
|
||||||
|
{#if isSSE || isResponseJson}
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
class:tab-btn-active={respBodyTab === "raw"}
|
||||||
|
onclick={() => (respBodyTab = "raw")}>Raw</button
|
||||||
|
>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
class="tab-btn"
|
||||||
|
onclick={() => copyToClipboard(getCopyText(), "resp")}
|
||||||
|
>
|
||||||
|
{#if copiedResp}
|
||||||
|
Copied!
|
||||||
|
{:else}
|
||||||
|
Copy
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
class="mt-1 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
{#if respBodyTab === "chat"}
|
||||||
|
<div class="p-3 text-sm space-y-3">
|
||||||
|
{#if sseChat.reasoning}
|
||||||
|
<div>
|
||||||
|
<div
|
||||||
|
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||||
|
>
|
||||||
|
Reasoning
|
||||||
|
</div>
|
||||||
|
<pre
|
||||||
|
class="font-mono whitespace-pre-wrap break-all text-txtsecondary">{sseChat.reasoning}</pre>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
{#if sseChat.content}
|
||||||
|
<div>
|
||||||
|
{#if sseChat.reasoning}
|
||||||
|
<div
|
||||||
|
class="text-xs font-semibold uppercase tracking-wider text-txtsecondary mb-1"
|
||||||
|
>
|
||||||
|
Response
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
<pre
|
||||||
|
class="font-mono whitespace-pre-wrap break-all">{sseChat.content}</pre>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
{#if !sseChat.reasoning && !sseChat.content}
|
||||||
|
<pre class="font-mono">(empty)</pre>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<pre
|
||||||
|
class="p-3 text-sm font-mono whitespace-pre-wrap break-all">{displayedResponseBody || "(empty)"}</pre>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
{:else if responseBodyRaw}
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
<div class="p-3 text-sm text-txtsecondary italic">
|
||||||
|
(binary data - {responseContentType || "unknown content type"})
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<div
|
||||||
|
class="mt-2 bg-background rounded border border-card-border overflow-auto max-h-96"
|
||||||
|
>
|
||||||
|
<pre class="p-3 text-sm font-mono">(empty)</pre>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="p-4 border-t border-card-border flex justify-end">
|
||||||
|
<button onclick={() => dialogEl?.close()} class="btn"> Close </button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</dialog>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
.tab-btn {
|
||||||
|
padding: 2px 10px;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
color: var(--color-txtsecondary);
|
||||||
|
cursor: pointer;
|
||||||
|
border: 1px solid transparent;
|
||||||
|
background: transparent;
|
||||||
|
transition: all 0.15s;
|
||||||
|
}
|
||||||
|
.tab-btn:hover {
|
||||||
|
color: var(--color-txtmain);
|
||||||
|
background: var(--color-secondary);
|
||||||
|
}
|
||||||
|
.tab-btn-active {
|
||||||
|
color: var(--color-primary);
|
||||||
|
background: color-mix(in srgb, var(--color-primary) 12%, transparent);
|
||||||
|
border-color: color-mix(in srgb, var(--color-primary) 25%, transparent);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { connectionState } from "../stores/theme";
|
||||||
|
import { versionInfo } from "../stores/api";
|
||||||
|
|
||||||
|
let eventStatusColor = $derived.by(() => {
|
||||||
|
switch ($connectionState) {
|
||||||
|
case "connected":
|
||||||
|
return "bg-emerald-500";
|
||||||
|
case "connecting":
|
||||||
|
return "bg-amber-500";
|
||||||
|
case "disconnected":
|
||||||
|
default:
|
||||||
|
return "bg-red-500";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let tooltipText = $derived(
|
||||||
|
`Event Stream: ${$connectionState ?? "unknown"}\nAPI Version: ${$versionInfo?.version ?? "unknown"}\nCommit Hash: ${$versionInfo?.commit?.substring(0, 7) ?? "unknown"}\nBuild Date: ${$versionInfo?.build_date ?? "unknown"}`
|
||||||
|
);
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="flex items-center" title={tooltipText}>
|
||||||
|
<span class="inline-block w-3 h-3 rounded-full {eventStatusColor} mr-2"></span>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { link } from "svelte-spa-router";
|
||||||
|
import { screenWidth, toggleTheme, isDarkMode, appTitle, isNarrow } from "../stores/theme";
|
||||||
|
import { currentRoute } from "../stores/route";
|
||||||
|
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||||
|
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||||
|
|
||||||
|
function handleTitleChange(newTitle: string): void {
|
||||||
|
const sanitized = newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap";
|
||||||
|
appTitle.set(sanitized);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleKeyDown(e: KeyboardEvent): void {
|
||||||
|
if (e.key === "Enter") {
|
||||||
|
e.preventDefault();
|
||||||
|
const target = e.currentTarget as HTMLElement;
|
||||||
|
handleTitleChange(target.textContent || "(set title)");
|
||||||
|
target.blur();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleBlur(e: FocusEvent): void {
|
||||||
|
const target = e.currentTarget as HTMLElement;
|
||||||
|
handleTitleChange(target.textContent || "(set title)");
|
||||||
|
}
|
||||||
|
|
||||||
|
function isActive(path: string, current: string): boolean {
|
||||||
|
return path === "/" ? current === "/" : current.startsWith(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<header
|
||||||
|
class="flex items-center justify-between bg-surface border-b border-border px-4 {$isNarrow
|
||||||
|
? 'py-1 h-[60px]'
|
||||||
|
: 'p-2 h-[75px]'}"
|
||||||
|
>
|
||||||
|
{#if $screenWidth !== "xs" && $screenWidth !== "sm"}
|
||||||
|
<h1
|
||||||
|
contenteditable="true"
|
||||||
|
class="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||||
|
onblur={handleBlur}
|
||||||
|
onkeydown={handleKeyDown}
|
||||||
|
>
|
||||||
|
{$appTitle}
|
||||||
|
</h1>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
<menu class="flex items-center gap-4 overflow-x-auto">
|
||||||
|
<a
|
||||||
|
href="/"
|
||||||
|
use:link
|
||||||
|
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold underline underline-offset-4' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
||||||
|
>
|
||||||
|
Playground
|
||||||
|
</a>
|
||||||
|
<a
|
||||||
|
href="/models"
|
||||||
|
use:link
|
||||||
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
|
class:font-semibold={isActive("/models", $currentRoute)}
|
||||||
|
class:underline={isActive("/models", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/models", $currentRoute)}
|
||||||
|
>
|
||||||
|
Models
|
||||||
|
</a>
|
||||||
|
<a
|
||||||
|
href="/activity"
|
||||||
|
use:link
|
||||||
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
|
class:font-semibold={isActive("/activity", $currentRoute)}
|
||||||
|
class:underline={isActive("/activity", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/activity", $currentRoute)}
|
||||||
|
>
|
||||||
|
Activity
|
||||||
|
</a>
|
||||||
|
<a
|
||||||
|
href="/logs"
|
||||||
|
use:link
|
||||||
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
|
class:font-semibold={isActive("/logs", $currentRoute)}
|
||||||
|
class:underline={isActive("/logs", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/logs", $currentRoute)}
|
||||||
|
>
|
||||||
|
Logs
|
||||||
|
</a>
|
||||||
|
<button onclick={toggleTheme} title="Toggle theme">
|
||||||
|
{#if $isDarkMode}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path
|
||||||
|
fill-rule="evenodd"
|
||||||
|
d="M9.528 1.718a.75.75 0 0 1 .162.819A8.97 8.97 0 0 0 9 6a9 9 0 0 0 9 9 8.97 8.97 0 0 0 3.463-.69.75.75 0 0 1 .981.98 10.503 10.503 0 0 1-9.694 6.46c-5.799 0-10.5-4.7-10.5-10.5 0-4.368 2.667-8.112 6.46-9.694a.75.75 0 0 1 .818.162Z"
|
||||||
|
clip-rule="evenodd"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path
|
||||||
|
d="M12 2.25a.75.75 0 0 1 .75.75v2.25a.75.75 0 0 1-1.5 0V3a.75.75 0 0 1 .75-.75ZM7.5 12a4.5 4.5 0 1 1 9 0 4.5 4.5 0 0 1-9 0ZM18.894 6.166a.75.75 0 0 0-1.06-1.06l-1.591 1.59a.75.75 0 1 0 1.06 1.061l1.591-1.59ZM21.75 12a.75.75 0 0 1-.75.75h-2.25a.75.75 0 0 1 0-1.5H21a.75.75 0 0 1 .75.75ZM17.834 18.894a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 1 0-1.061 1.06l1.591 1.591ZM12 18a.75.75 0 0 1 .75.75V21a.75.75 0 0 1-1.5 0v-2.25A.75.75 0 0 1 12 18ZM7.758 17.303a.75.75 0 0 0-1.061-1.06l-1.591 1.59a.75.75 0 0 0 1.06 1.061l1.591-1.59ZM6 12a.75.75 0 0 1-.75.75H3a.75.75 0 0 1 0-1.5h2.25A.75.75 0 0 1 6 12ZM6.697 7.757a.75.75 0 0 0 1.06-1.06l-1.59-1.591a.75.75 0 0 0-1.061 1.06l1.59 1.591Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
<ConnectionStatus />
|
||||||
|
</menu>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
.activity-link {
|
||||||
|
background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7, #8b5cf6, #6366f1);
|
||||||
|
background-size: 200% 100%;
|
||||||
|
-webkit-background-clip: text;
|
||||||
|
background-clip: text;
|
||||||
|
-webkit-text-fill-color: transparent;
|
||||||
|
animation: gradient-shift 2s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes gradient-shift {
|
||||||
|
0% {
|
||||||
|
background-position: 0% 50%;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
background-position: 200% 50%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { persistentStore } from "../stores/persistent";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
id: string;
|
||||||
|
title: string;
|
||||||
|
logData: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
let { id, title, logData }: Props = $props();
|
||||||
|
|
||||||
|
let filterRegex = $state("");
|
||||||
|
|
||||||
|
// Create persistent stores for this panel (id is intentionally captured at init time)
|
||||||
|
// svelte-ignore state_referenced_locally
|
||||||
|
const fontSizeStore = persistentStore<"xxs" | "xs" | "small" | "normal">(`logPanel-${id}-fontSize`, "normal");
|
||||||
|
// svelte-ignore state_referenced_locally
|
||||||
|
const wrapTextStore = persistentStore<boolean>(`logPanel-${id}-wrapText`, false);
|
||||||
|
// svelte-ignore state_referenced_locally
|
||||||
|
const showFilterStore = persistentStore<boolean>(`logPanel-${id}-showFilter`, false);
|
||||||
|
|
||||||
|
let textWrapClass = $derived($wrapTextStore ? "whitespace-pre-wrap" : "whitespace-pre");
|
||||||
|
|
||||||
|
function toggleFontSize(): void {
|
||||||
|
fontSizeStore.update((prev) => {
|
||||||
|
switch (prev) {
|
||||||
|
case "xxs": return "xs";
|
||||||
|
case "xs": return "small";
|
||||||
|
case "small": return "normal";
|
||||||
|
case "normal": return "xxs";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleWrapText(): void {
|
||||||
|
wrapTextStore.update((prev) => !prev);
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleFilter(): void {
|
||||||
|
if ($showFilterStore) {
|
||||||
|
showFilterStore.set(false);
|
||||||
|
filterRegex = "";
|
||||||
|
} else {
|
||||||
|
showFilterStore.set(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fontSizeClass = $derived.by(() => {
|
||||||
|
switch ($fontSizeStore) {
|
||||||
|
case "xxs": return "text-[0.5rem]";
|
||||||
|
case "xs": return "text-[0.75rem]";
|
||||||
|
case "small": return "text-[0.875rem]";
|
||||||
|
case "normal": return "text-base";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let filteredLogs = $derived.by(() => {
|
||||||
|
if (!filterRegex) return logData;
|
||||||
|
try {
|
||||||
|
const regex = new RegExp(filterRegex, "i");
|
||||||
|
return logData.split("\n").filter((line) => regex.test(line)).join("\n");
|
||||||
|
} catch {
|
||||||
|
return logData;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let preElement: HTMLPreElement;
|
||||||
|
let userScrolledUp = $state(false);
|
||||||
|
|
||||||
|
function handleScroll() {
|
||||||
|
if (!preElement) return;
|
||||||
|
const { scrollTop, scrollHeight, clientHeight } = preElement;
|
||||||
|
userScrolledUp = scrollHeight - scrollTop - clientHeight > 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto scroll to bottom when logs change, unless user has scrolled up
|
||||||
|
$effect(() => {
|
||||||
|
if (preElement && filteredLogs && !userScrolledUp) {
|
||||||
|
preElement.scrollTop = preElement.scrollHeight;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full w-full p-1">
|
||||||
|
<div class="p-4">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<h3 class="m-0 text-lg p-0">{title}</h3>
|
||||||
|
|
||||||
|
<div class="flex gap-2 items-center">
|
||||||
|
<button class="btn border-0" onclick={toggleFontSize} title="Change font size">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||||
|
<path d="M2 4v3h5v12h3V7h5V4H2zm19 5h-9v3h3v7h3v-7h3V9z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
<button class="btn border-0" onclick={toggleWrapText} title="Toggle text wrap">
|
||||||
|
{#if $wrapTextStore}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||||
|
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||||
|
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h10.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
<button class="btn border-0" onclick={toggleFilter} title="Toggle filter">
|
||||||
|
{#if $showFilterStore}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-4 h-4">
|
||||||
|
<path fill-rule="evenodd" d="M10.5 3.75a6.75 6.75 0 1 0 0 13.5 6.75 6.75 0 0 0 0-13.5ZM2.25 10.5a8.25 8.25 0 1 1 14.59 5.28l4.69 4.69a.75.75 0 1 1-1.06 1.06l-4.69-4.69A8.25 8.25 0 0 1 2.25 10.5Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{#if $showFilterStore}
|
||||||
|
<div class="mt-2 flex gap-2 items-center w-full">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
class="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||||
|
placeholder="Filter logs (regex)..."
|
||||||
|
bind:value={filterRegex}
|
||||||
|
/>
|
||||||
|
<button class="pl-2" onclick={() => (filterRegex = "")} aria-label="Clear filter">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||||
|
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm-1.72 6.97a.75.75 0 1 0-1.06 1.06L10.94 12l-1.72 1.72a.75.75 0 1 0 1.06 1.06L12 13.06l1.72 1.72a.75.75 0 1 0 1.06-1.06L13.06 12l1.72-1.72a.75.75 0 1 0-1.06-1.06L12 10.94l-1.72-1.72Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
<div class="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||||
|
<pre bind:this={preElement} onscroll={handleScroll} class="{textWrapClass} {fontSizeClass} h-full overflow-auto p-4">{filteredLogs}</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { models, loadModel, unloadAllModels, unloadSingleModel } from "../stores/api";
|
||||||
|
import { isNarrow } from "../stores/theme";
|
||||||
|
import { persistentStore } from "../stores/persistent";
|
||||||
|
import type { Model } from "../lib/types";
|
||||||
|
|
||||||
|
let isUnloading = $state(false);
|
||||||
|
let menuOpen = $state(false);
|
||||||
|
|
||||||
|
const showUnlistedStore = persistentStore<boolean>("showUnlisted", true);
|
||||||
|
const showIdorNameStore = persistentStore<"id" | "name">("showIdorName", "id");
|
||||||
|
|
||||||
|
let filteredModels = $derived.by(() => {
|
||||||
|
const filtered = $models.filter((model) => $showUnlistedStore || !model.unlisted);
|
||||||
|
const peerModels = filtered.filter((m) => m.peerID);
|
||||||
|
|
||||||
|
// Group peer models by peerID
|
||||||
|
const grouped = peerModels.reduce(
|
||||||
|
(acc, model) => {
|
||||||
|
const peerId = model.peerID || "unknown";
|
||||||
|
if (!acc[peerId]) acc[peerId] = [];
|
||||||
|
acc[peerId].push(model);
|
||||||
|
return acc;
|
||||||
|
},
|
||||||
|
{} as Record<string, Model[]>
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
regularModels: filtered.filter((m) => !m.peerID),
|
||||||
|
peerModelsByPeerId: grouped,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
async function handleUnloadAllModels(): Promise<void> {
|
||||||
|
isUnloading = true;
|
||||||
|
try {
|
||||||
|
await unloadAllModels();
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
} finally {
|
||||||
|
setTimeout(() => (isUnloading = false), 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleIdorName(): void {
|
||||||
|
showIdorNameStore.update((prev) => (prev === "name" ? "id" : "name"));
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleShowUnlisted(): void {
|
||||||
|
showUnlistedStore.update((prev) => !prev);
|
||||||
|
}
|
||||||
|
|
||||||
|
function getModelDisplay(model: Model): string {
|
||||||
|
return $showIdorNameStore === "id" ? model.id : (model.name || model.id);
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="card h-full flex flex-col">
|
||||||
|
<div class="shrink-0">
|
||||||
|
<div class="flex justify-between items-baseline">
|
||||||
|
<h2 class={$isNarrow ? "text-xl" : ""}>Models</h2>
|
||||||
|
{#if $isNarrow}
|
||||||
|
<div class="relative">
|
||||||
|
<button class="btn text-base flex items-center gap-2 py-1" onclick={() => (menuOpen = !menuOpen)} aria-label="Toggle menu">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path fill-rule="evenodd" d="M3 6.75A.75.75 0 0 1 3.75 6h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 6.75ZM3 12a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75A.75.75 0 0 1 3 12Zm0 5.25a.75.75 0 0 1 .75-.75h16.5a.75.75 0 0 1 0 1.5H3.75a.75.75 0 0 1-.75-.75Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
{#if menuOpen}
|
||||||
|
<div class="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||||
|
<button
|
||||||
|
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onclick={() => { toggleIdorName(); menuOpen = false; }}
|
||||||
|
>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{$showIdorNameStore === "id" ? "Show Name" : "Show ID"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onclick={() => { toggleShowUnlisted(); menuOpen = false; }}
|
||||||
|
>
|
||||||
|
{#if $showUnlistedStore}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||||
|
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||||
|
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||||
|
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
{$showUnlistedStore ? "Hide Unlisted" : "Show Unlisted"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
class="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onclick={() => { handleUnloadAllModels(); menuOpen = false; }}
|
||||||
|
disabled={isUnloading}
|
||||||
|
>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||||
|
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{isUnloading ? "Unloading..." : "Unload All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
{#if !$isNarrow}
|
||||||
|
<div class="flex justify-between">
|
||||||
|
<div class="flex gap-2">
|
||||||
|
<button class="btn text-base flex items-center gap-2" onclick={toggleIdorName} style="line-height: 1.2">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path fill-rule="evenodd" d="M15.97 2.47a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1 0 1.06l-4.5 4.5a.75.75 0 1 1-1.06-1.06l3.22-3.22H7.5a.75.75 0 0 1 0-1.5h11.69l-3.22-3.22a.75.75 0 0 1 0-1.06Zm-7.94 9a.75.75 0 0 1 0 1.06l-3.22 3.22H16.5a.75.75 0 0 1 0 1.5H4.81l3.22 3.22a.75.75 0 1 1-1.06 1.06l-4.5-4.5a.75.75 0 0 1 0-1.06l4.5-4.5a.75.75 0 0 1 1.06 0Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{$showIdorNameStore === "id" ? "ID" : "Name"}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<button class="btn text-base flex items-center gap-2" onclick={toggleShowUnlisted} style="line-height: 1.2">
|
||||||
|
{#if $showUnlistedStore}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path d="M12 15a3 3 0 1 0 0-6 3 3 0 0 0 0 6Z" />
|
||||||
|
<path fill-rule="evenodd" d="M1.323 11.447C2.811 6.976 7.028 3.75 12.001 3.75c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113-1.487 4.471-5.705 7.697-10.677 7.697-4.97 0-9.186-3.223-10.675-7.69a1.762 1.762 0 0 1 0-1.113ZM17.25 12a5.25 5.25 0 1 1-10.5 0 5.25 5.25 0 0 1 10.5 0Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
<path d="M3.53 2.47a.75.75 0 0 0-1.06 1.06l18 18a.75.75 0 1 0 1.06-1.06l-18-18ZM22.676 12.553a11.249 11.249 0 0 1-2.631 4.31l-3.099-3.099a5.25 5.25 0 0 0-6.71-6.71L7.759 4.577a11.217 11.217 0 0 1 4.242-.827c4.97 0 9.185 3.223 10.675 7.69.12.362.12.752 0 1.113Z" />
|
||||||
|
<path d="M15.75 12c0 .18-.013.357-.037.53l-4.244-4.243A3.75 3.75 0 0 1 15.75 12ZM12.53 15.713l-4.243-4.244a3.75 3.75 0 0 0 4.244 4.243Z" />
|
||||||
|
<path d="M6.75 12c0-.619.107-1.213.304-1.764l-3.1-3.1a11.25 11.25 0 0 0-2.63 4.31c-.12.362-.12.752 0 1.114 1.489 4.467 5.704 7.69 10.675 7.69 1.5 0 2.933-.294 4.242-.827l-2.477-2.477A5.25 5.25 0 0 1 6.75 12Z" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
unlisted
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<button class="btn text-base flex items-center gap-2" onclick={handleUnloadAllModels} disabled={isUnloading}>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-6 h-6">
|
||||||
|
<path fill-rule="evenodd" d="M12 2.25c-5.385 0-9.75 4.365-9.75 9.75s4.365 9.75 9.75 9.75 9.75-4.365 9.75-9.75S17.385 2.25 12 2.25Zm.53 5.47a.75.75 0 0 0-1.06 0l-3 3a.75.75 0 1 0 1.06 1.06l1.72-1.72v5.69a.75.75 0 0 0 1.5 0v-5.69l1.72 1.72a.75.75 0 1 0 1.06-1.06l-3-3Z" clip-rule="evenodd" />
|
||||||
|
</svg>
|
||||||
|
{isUnloading ? "Unloading..." : "Unload All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="flex-1 overflow-y-auto">
|
||||||
|
<table class="w-full">
|
||||||
|
<thead class="sticky top-0 bg-card z-10">
|
||||||
|
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||||
|
<th>{$showIdorNameStore === "id" ? "Model ID" : "Name"}</th>
|
||||||
|
<th></th>
|
||||||
|
<th>State</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{#each filteredModels.regularModels as model (model.id)}
|
||||||
|
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||||
|
<td class={model.unlisted ? "text-txtsecondary" : ""}>
|
||||||
|
<a href="/upstream/{model.id}/" class="font-semibold" target="_blank">
|
||||||
|
{getModelDisplay(model)}
|
||||||
|
</a>
|
||||||
|
{#if model.description}
|
||||||
|
<p class={model.unlisted ? "text-opacity-70" : ""}><em>{model.description}</em></p>
|
||||||
|
{/if}
|
||||||
|
{#if model.aliases && model.aliases.length > 0}
|
||||||
|
<p class="text-xs text-txtsecondary">Aliases: {model.aliases.join(", ")}</p>
|
||||||
|
{/if}
|
||||||
|
</td>
|
||||||
|
<td class="w-12">
|
||||||
|
{#if model.state === "stopped"}
|
||||||
|
<button class="btn btn--sm" onclick={() => loadModel(model.id)}>Load</button>
|
||||||
|
{:else}
|
||||||
|
<button class="btn btn--sm" onclick={() => unloadSingleModel(model.id)} disabled={model.state !== "ready"}>Unload</button>
|
||||||
|
{/if}
|
||||||
|
</td>
|
||||||
|
<td class="w-20">
|
||||||
|
<span class="w-16 text-center status status--{model.state}">{model.state}</span>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
{/each}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
{#if Object.keys(filteredModels.peerModelsByPeerId).length > 0}
|
||||||
|
<h3 class="mt-8 mb-2">Peer Models</h3>
|
||||||
|
{#each Object.entries(filteredModels.peerModelsByPeerId).sort(([a], [b]) => a.localeCompare(b)) as [peerId, peerModels] (peerId)}
|
||||||
|
<div class="mb-4">
|
||||||
|
<table class="w-full">
|
||||||
|
<thead class="sticky top-0 bg-card z-10">
|
||||||
|
<tr class="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||||
|
<th class="font-semibold">{peerId}</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{#each peerModels as model (model.id)}
|
||||||
|
<tr class="border-b hover:bg-secondary-hover border-gray-200">
|
||||||
|
<td class="pl-8 {model.unlisted ? 'text-txtsecondary' : ''}">
|
||||||
|
<span>{model.id}</span>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
{/each}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
{/each}
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,152 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import type { Snippet } from "svelte";
|
||||||
|
import { onMount } from "svelte";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
direction: "horizontal" | "vertical";
|
||||||
|
storageKey: string;
|
||||||
|
leftPanel: Snippet;
|
||||||
|
rightPanel: Snippet;
|
||||||
|
defaultSize?: number;
|
||||||
|
minSize?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
let { direction, storageKey, leftPanel, rightPanel, defaultSize = 50, minSize = 5 }: Props = $props();
|
||||||
|
|
||||||
|
let containerRef: HTMLDivElement;
|
||||||
|
let isDragging = $state(false);
|
||||||
|
// svelte-ignore state_referenced_locally
|
||||||
|
let leftSize = $state(defaultSize);
|
||||||
|
|
||||||
|
// Load saved size from localStorage
|
||||||
|
onMount(() => {
|
||||||
|
const saved = localStorage.getItem(`panel-size-${storageKey}`);
|
||||||
|
if (saved) {
|
||||||
|
const parsed = parseFloat(saved);
|
||||||
|
if (!isNaN(parsed) && parsed >= minSize && parsed <= 100 - minSize) {
|
||||||
|
leftSize = parsed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
function saveSize(): void {
|
||||||
|
localStorage.setItem(`panel-size-${storageKey}`, String(leftSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseDown(e: MouseEvent): void {
|
||||||
|
e.preventDefault();
|
||||||
|
isDragging = true;
|
||||||
|
document.addEventListener("mousemove", handleMouseMove);
|
||||||
|
document.addEventListener("mouseup", handleMouseUp);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleTouchStart(_e: TouchEvent): void {
|
||||||
|
isDragging = true;
|
||||||
|
document.addEventListener("touchmove", handleTouchMove);
|
||||||
|
document.addEventListener("touchend", handleTouchEnd);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseMove(e: MouseEvent): void {
|
||||||
|
if (!isDragging || !containerRef) return;
|
||||||
|
updateSize(e.clientX, e.clientY);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleTouchMove(e: TouchEvent): void {
|
||||||
|
if (!isDragging || !containerRef || e.touches.length === 0) return;
|
||||||
|
updateSize(e.touches[0].clientX, e.touches[0].clientY);
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateSize(clientX: number, clientY: number): void {
|
||||||
|
const rect = containerRef.getBoundingClientRect();
|
||||||
|
|
||||||
|
let newSize: number;
|
||||||
|
if (direction === "horizontal") {
|
||||||
|
newSize = ((clientX - rect.left) / rect.width) * 100;
|
||||||
|
} else {
|
||||||
|
newSize = ((clientY - rect.top) / rect.height) * 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamp size
|
||||||
|
newSize = Math.max(minSize, Math.min(100 - minSize, newSize));
|
||||||
|
leftSize = newSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseUp(): void {
|
||||||
|
isDragging = false;
|
||||||
|
saveSize();
|
||||||
|
document.removeEventListener("mousemove", handleMouseMove);
|
||||||
|
document.removeEventListener("mouseup", handleMouseUp);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleTouchEnd(): void {
|
||||||
|
isDragging = false;
|
||||||
|
saveSize();
|
||||||
|
document.removeEventListener("touchmove", handleTouchMove);
|
||||||
|
document.removeEventListener("touchend", handleTouchEnd);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleKeyDown(e: KeyboardEvent): void {
|
||||||
|
const step = 2; // 2% increment for keyboard navigation
|
||||||
|
const key = e.key;
|
||||||
|
|
||||||
|
if (direction === "horizontal" && (key === "ArrowLeft" || key === "ArrowRight")) {
|
||||||
|
e.preventDefault();
|
||||||
|
const delta = key === "ArrowLeft" ? -step : step;
|
||||||
|
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||||
|
leftSize = newSize;
|
||||||
|
saveSize();
|
||||||
|
} else if (direction === "vertical" && (key === "ArrowUp" || key === "ArrowDown")) {
|
||||||
|
e.preventDefault();
|
||||||
|
const delta = key === "ArrowUp" ? -step : step;
|
||||||
|
const newSize = Math.max(minSize, Math.min(100 - minSize, leftSize + delta));
|
||||||
|
leftSize = newSize;
|
||||||
|
saveSize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let containerClass = $derived(direction === "horizontal" ? "flex-row" : "flex-col");
|
||||||
|
|
||||||
|
let handleClass = $derived(
|
||||||
|
direction === "horizontal"
|
||||||
|
? "w-2 h-full cursor-col-resize"
|
||||||
|
: "w-full h-2 cursor-row-resize"
|
||||||
|
);
|
||||||
|
|
||||||
|
let leftStyle = $derived(
|
||||||
|
direction === "horizontal"
|
||||||
|
? `width: ${leftSize}%; min-width: ${minSize}%`
|
||||||
|
: `height: ${leftSize}%; min-height: ${minSize}%`
|
||||||
|
);
|
||||||
|
|
||||||
|
let rightStyle = $derived(
|
||||||
|
direction === "horizontal"
|
||||||
|
? `width: ${100 - leftSize}%; min-width: ${minSize}%`
|
||||||
|
: `height: ${100 - leftSize}%; min-height: ${minSize}%`
|
||||||
|
);
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div bind:this={containerRef} class="flex {containerClass} h-full w-full gap-2">
|
||||||
|
<div style={leftStyle} class="overflow-hidden">
|
||||||
|
{@render leftPanel()}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- svelte-ignore a11y_no_noninteractive_tabindex -->
|
||||||
|
<!-- svelte-ignore a11y_no_noninteractive_element_interactions -->
|
||||||
|
<div
|
||||||
|
role="separator"
|
||||||
|
tabindex="0"
|
||||||
|
class="{handleClass} bg-primary hover:bg-success transition-colors rounded flex-shrink-0"
|
||||||
|
onmousedown={handleMouseDown}
|
||||||
|
ontouchstart={handleTouchStart}
|
||||||
|
onkeydown={handleKeyDown}
|
||||||
|
aria-label="Resize panels"
|
||||||
|
aria-orientation={direction}
|
||||||
|
aria-valuenow={Math.round(leftSize)}
|
||||||
|
aria-valuemin={minSize}
|
||||||
|
aria-valuemax={100 - minSize}
|
||||||
|
></div>
|
||||||
|
|
||||||
|
<div style={rightStyle} class="overflow-hidden">
|
||||||
|
{@render rightPanel()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,145 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import type { HistogramData } from "../lib/types";
|
||||||
|
|
||||||
|
let {
|
||||||
|
data,
|
||||||
|
unit = "tokens/sec",
|
||||||
|
colorClass = "text-blue-500 dark:text-blue-400",
|
||||||
|
}: {
|
||||||
|
data: HistogramData;
|
||||||
|
unit?: string;
|
||||||
|
colorClass?: string;
|
||||||
|
} = $props();
|
||||||
|
|
||||||
|
const height = 250;
|
||||||
|
const padding = { top: 30, right: 20, bottom: 40, left: 75 };
|
||||||
|
const viewBoxWidth = 1200;
|
||||||
|
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||||
|
const chartHeight = height - padding.top - padding.bottom;
|
||||||
|
|
||||||
|
let maxCount = $derived(Math.max(...data.bins));
|
||||||
|
let barWidth = $derived(chartWidth / data.bins.length);
|
||||||
|
let range = $derived(data.max - data.min);
|
||||||
|
|
||||||
|
function getXPosition(value: number): number {
|
||||||
|
return padding.left + ((value - data.min) / range) * chartWidth;
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="mt-2 w-full">
|
||||||
|
<svg viewBox="0 0 {viewBoxWidth} {height}" class="w-full h-auto" preserveAspectRatio="xMidYMid meet">
|
||||||
|
<!-- Y-axis -->
|
||||||
|
<line
|
||||||
|
x1={padding.left}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={padding.left}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1"
|
||||||
|
opacity="0.3"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- Y-axis ticks and labels -->
|
||||||
|
{#each [0, 0.5, 1] as fraction}
|
||||||
|
{@const tickCount = Math.round(maxCount * fraction)}
|
||||||
|
{@const tickY = height - padding.bottom - fraction * chartHeight}
|
||||||
|
<line
|
||||||
|
x1={padding.left - 8}
|
||||||
|
y1={tickY}
|
||||||
|
x2={padding.left}
|
||||||
|
y2={tickY}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1"
|
||||||
|
opacity="0.4"
|
||||||
|
/>
|
||||||
|
<text x={padding.left - 10} y={tickY + 10} font-size="26" fill="currentColor" opacity="0.8" text-anchor="end">
|
||||||
|
{tickCount}
|
||||||
|
</text>
|
||||||
|
{/each}
|
||||||
|
|
||||||
|
<!-- X-axis -->
|
||||||
|
<line
|
||||||
|
x1={padding.left}
|
||||||
|
y1={height - padding.bottom}
|
||||||
|
x2={viewBoxWidth - padding.right}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1"
|
||||||
|
opacity="0.3"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- Histogram bars -->
|
||||||
|
{#each data.bins as count, i}
|
||||||
|
{@const barHeight = maxCount > 0 ? (count / maxCount) * chartHeight : 0}
|
||||||
|
{@const x = padding.left + i * barWidth}
|
||||||
|
{@const y = height - padding.bottom - barHeight}
|
||||||
|
{@const binStart = data.min + i * data.binSize}
|
||||||
|
{@const binEnd = binStart + data.binSize}
|
||||||
|
<g>
|
||||||
|
<rect
|
||||||
|
{x}
|
||||||
|
{y}
|
||||||
|
width={Math.max(barWidth - 1, 1)}
|
||||||
|
height={barHeight}
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.6"
|
||||||
|
class="{colorClass} hover:opacity-90 transition-opacity cursor-pointer"
|
||||||
|
/>
|
||||||
|
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} ${unit}\nCount: ${count}`}</title>
|
||||||
|
</g>
|
||||||
|
{/each}
|
||||||
|
|
||||||
|
<!-- Percentile lines -->
|
||||||
|
<line
|
||||||
|
x1={getXPosition(data.p50)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(data.p50)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
stroke-dasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
class="text-gray-600 dark:text-gray-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<line
|
||||||
|
x1={getXPosition(data.p95)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(data.p95)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
stroke-dasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
class="text-orange-500 dark:text-orange-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<line
|
||||||
|
x1={getXPosition(data.p99)}
|
||||||
|
y1={padding.top}
|
||||||
|
x2={getXPosition(data.p99)}
|
||||||
|
y2={height - padding.bottom}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
stroke-dasharray="4 2"
|
||||||
|
opacity="0.7"
|
||||||
|
class="text-green-500 dark:text-green-400"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- X-axis labels -->
|
||||||
|
<text x={padding.left} y={height - 8} font-size="26" fill="currentColor" opacity="0.8" text-anchor="start">
|
||||||
|
{data.min.toFixed(1)}
|
||||||
|
</text>
|
||||||
|
|
||||||
|
<text
|
||||||
|
x={viewBoxWidth - padding.right}
|
||||||
|
y={height - 8}
|
||||||
|
font-size="26"
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.8"
|
||||||
|
text-anchor="end"
|
||||||
|
>
|
||||||
|
{data.max.toFixed(1)}
|
||||||
|
</text>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
interface Props {
|
||||||
|
content: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
let { content }: Props = $props();
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="relative group inline-block">
|
||||||
|
<span class="cursor-help">ⓘ</span>
|
||||||
|
<div
|
||||||
|
class="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||||
|
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||||
|
opacity-0 group-hover:opacity-100 transition-opacity
|
||||||
|
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||||
|
>
|
||||||
|
{content}
|
||||||
|
<div class="absolute bottom-full left-1/2 transform -translate-x-1/2 border-4 border-transparent border-b-gray-900"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||