Compare commits
366 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e8d4384cd2 | |||
| ce28485be2 | |||
| 3cd7837b1f | |||
| 0b31ccacc1 | |||
| 5938dbee8f | |||
| 66639e83f7 | |||
| 625b296720 | |||
| 231e62291c | |||
| 57ac666598 | |||
| 69728301f5 | |||
| c176fa70f1 | |||
| 5e3c646829 | |||
| c3f0d43e6e | |||
| f6cf9f5844 | |||
| 121fd93ad8 | |||
| 17233e9278 | |||
| 4866d16c3e | |||
| 35193f82f1 | |||
| 40e39f7a86 | |||
| a9d840ffd7 | |||
| 7b2b82777f | |||
| d87f0ce2c5 | |||
| 06bc6a614c | |||
| a37b4866d8 | |||
| 981910d734 | |||
| a185efe37e | |||
| 1dd1aadf93 | |||
| 955900972a | |||
| c2c8cfaf81 | |||
| 1e440770ea | |||
| c794273c83 | |||
| 6574a52cbb | |||
| 8fabc75634 | |||
| e5e7391b6d | |||
| 2c282dccad | |||
| 916d13f5bd | |||
| a3725e7d09 | |||
| 15bd55d3a9 | |||
| c3c258a55d | |||
| 29a38fde0d | |||
| d569681daa | |||
| 24efdb76b1 | |||
| cc77139ff8 | |||
| 390a35bf93 | |||
| 181f71ca11 | |||
| 49546e2cf2 | |||
| 2c078964f4 | |||
| 175bb36fb1 | |||
| aedb640471 | |||
| 2f377f6dc6 | |||
| 64e4c79fc3 | |||
| 19fb5f35e9 | |||
| b45102bde8 | |||
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e | |||
| 86e9b93c37 | |||
| 3acace810f | |||
| 554d29e87d | |||
| 3567b7df08 | |||
| 38738525c9 | |||
| c0fc858193 | |||
| b429349e8a | |||
| eab2efd7b5 | |||
| 6aedbe121a | |||
| b24467ab89 | |||
| 12b69fb718 | |||
| f91a8b2462 | |||
| a89b803d4a | |||
| f852689104 | |||
| e250e71e59 | |||
| d18dc26d01 | |||
| 8357714421 | |||
| c07179d6e2 | |||
| 7ff50631e0 | |||
| 9fc0431531 | |||
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 | |||
| fc3bb716df | |||
| c36986fef6 | |||
| 558801db1a | |||
| b21dee27c1 | |||
| f58c8c8ec5 | |||
| 954e2dee73 | |||
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 | |||
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 | |||
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 | |||
| 3a1e9f81f1 | |||
| 72c883f36c | |||
| 1b04d034cf | |||
| 2e45f5692a | |||
| c97b80bdfe | |||
| ae3ef9bc39 | |||
| db6715bec3 | |||
| da5d9e8a6a | |||
| 84b667ca7a | |||
| 29657106fc | |||
| 9c8860471e | |||
| 9b4e3f307e | |||
| 6fe37c3abf | |||
| 7f45493a37 | |||
| 891f6a5b5a | |||
| 7183f6b43d | |||
| d89bfeb441 | |||
| 9a0c6bed40 | |||
| d6ca535939 |
@@ -0,0 +1,22 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
@@ -0,0 +1,39 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: I found a defect
|
||||
title: ''
|
||||
labels: 'unconfirmed bug'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
> [!IMPORTANT]
|
||||
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Expected behaviour**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Operating system and version**
|
||||
|
||||
- OS: (linux, osx, windows, freebsd, etc)
|
||||
- GPUs: (list architecture)
|
||||
|
||||
**My Configuration**
|
||||
|
||||
```yaml
|
||||
# copy / paste your configuration here
|
||||
```
|
||||
|
||||
**Proxy Logs**
|
||||
|
||||
```
|
||||
# copy / paste from /logs
|
||||
```
|
||||
|
||||
**Upstream Logs**
|
||||
|
||||
```
|
||||
# copy/paste from /logs
|
||||
```
|
||||
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,56 @@
|
||||
name: Validate JSON Schema
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
- "config.example.yaml"
|
||||
- ".github/workflows/config-schema.yml"
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
validate-schema:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Validate JSON Schema
|
||||
run: |
|
||||
# Check if the file is valid JSON
|
||||
if ! jq empty config-schema.json 2>/dev/null; then
|
||||
echo "Error: config-schema.json is not valid JSON"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate that it's a valid JSON Schema
|
||||
# Check for required $schema field
|
||||
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
|
||||
echo "Warning: config-schema.json should have a \$schema field"
|
||||
fi
|
||||
|
||||
# Check that it has either properties or definitions
|
||||
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
|
||||
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
|
||||
fi
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install check-jsonschema
|
||||
run: pip install check-jsonschema
|
||||
|
||||
- name: Validate config.example.yaml against schema
|
||||
run: check-jsonschema --schemafile config-schema.json config.example.yaml
|
||||
@@ -0,0 +1,73 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
- 'docker/*.Containerfile'
|
||||
|
||||
# grant permissions on GITHUB_TOKEN to publish packages
|
||||
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, cuda13, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,66 @@
|
||||
name: Windows CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: make simple-responder-windows
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
run: make test-all
|
||||
@@ -0,0 +1,69 @@
|
||||
name: Linux CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- "**/*.go"
|
||||
- "!cmd/**"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- "Makefile"
|
||||
- ".github/workflows/go-ci.yml"
|
||||
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
paths:
|
||||
- "**/*.go"
|
||||
- "!cmd/**"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- "Makefile"
|
||||
- ".github/workflows/go-ci.yml"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
run: |
|
||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||
gofmt -l . | grep -v 'event/.*_test.go'
|
||||
exit 1
|
||||
fi
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -3,7 +3,14 @@ name: goreleaser
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag version to release (e.g. v144)"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -12,22 +19,56 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
-
|
||||
name: Set up Go
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||
distribution: goreleaser
|
||||
# 'latest', 'nightly', or a semver
|
||||
version: '~> v2'
|
||||
version: "~> v2"
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
trigger-tap-update:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: "Resolve tag to dispatch"
|
||||
id: tag
|
||||
run: |
|
||||
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: "Trigger tap repository update"
|
||||
uses: peter-evans/repository-dispatch@v2
|
||||
with:
|
||||
token: ${{ secrets.TAP_REPO_PAT }}
|
||||
repository: mostlygeek/homebrew-llama-swap
|
||||
event-type: new-release
|
||||
client-payload: |
|
||||
{
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
name: UI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ui-svelte/package-lock.json
|
||||
|
||||
- name: Run UI tests
|
||||
run: make test-ui
|
||||
@@ -0,0 +1,136 @@
|
||||
name: Build Unified Docker Image
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_cpp_ref:
|
||||
description: "llama.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
whisper_ref:
|
||||
description: "whisper.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
sd_ref:
|
||||
description: "stable-diffusion.cpp commit hash, tag, or branch"
|
||||
required: false
|
||||
default: "master"
|
||||
ik_llama_ref:
|
||||
description: "ik_llama.cpp commit hash, tag, or branch (CUDA only)"
|
||||
required: false
|
||||
default: "main"
|
||||
llama_swap_version:
|
||||
description: "llama-swap version (e.g. v198, latest, main)"
|
||||
required: false
|
||||
default: "main"
|
||||
build_cuda:
|
||||
description: "Build CUDA image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
build_vulkan:
|
||||
description: "Build Vulkan image"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
push_to_ghcr:
|
||||
description: "Push images to ghcr.io"
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
run: |
|
||||
backends=()
|
||||
# schedule uses defaults (build both); workflow_dispatch respects inputs
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_cuda }}" == "true" ]]; then
|
||||
backends+=("cuda")
|
||||
fi
|
||||
if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ inputs.build_vulkan }}" == "true" ]]; then
|
||||
backends+=("vulkan")
|
||||
fi
|
||||
matrix=$(printf '%s\n' "${backends[@]}" | jq -R . | jq -sc .)
|
||||
echo "matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
|
||||
build:
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.matrix != '[]' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: ${{ fromJSON(needs.setup.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
# On GitHub Actions runners, create a fresh builder.
|
||||
# When running locally under act, skip this and reuse the existing
|
||||
# llama-swap-builder (which has ccache warm) to avoid exhausting disk.
|
||||
- name: Set up Docker Buildx
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: ${{ !env.ACT }}
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build unified Docker image (${{ matrix.backend }})
|
||||
env:
|
||||
LLAMA_REF: ${{ inputs.llama_cpp_ref || 'master' }}
|
||||
WHISPER_REF: ${{ inputs.whisper_ref || 'master' }}
|
||||
SD_REF: ${{ inputs.sd_ref || 'master' }}
|
||||
IK_LLAMA_REF: ${{ inputs.ik_llama_ref || 'main' }}
|
||||
LS_VERSION: ${{ inputs.llama_swap_version || 'main' }}
|
||||
DOCKER_IMAGE_TAG: ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}
|
||||
# When running under act, use the local builder that has warm ccache.
|
||||
# On GitHub Actions, BUILDX_BUILDER is unset so docker uses the builder
|
||||
# created by setup-buildx-action above.
|
||||
BUILDX_BUILDER: ${{ env.ACT == 'true' && 'llama-swap-builder' || '' }}
|
||||
run: |
|
||||
chmod +x docker/unified/build-image.sh
|
||||
docker/unified/build-image.sh --${{ matrix.backend }}
|
||||
|
||||
- name: Push to GitHub Container Registry
|
||||
if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
|
||||
run: |
|
||||
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
|
||||
DATE_TAG=$(date -u +%Y-%m-%d)
|
||||
|
||||
docker push "${BASE_TAG}"
|
||||
docker tag "${BASE_TAG}" "${BASE_TAG}-${DATE_TAG}"
|
||||
docker push "${BASE_TAG}-${DATE_TAG}"
|
||||
|
||||
ROOTLESS_TAG="${BASE_TAG}-rootless"
|
||||
docker push "${ROOTLESS_TAG}"
|
||||
docker tag "${ROOTLESS_TAG}" "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||
docker push "${ROOTLESS_TAG}-${DATE_TAG}"
|
||||
+3
-1
@@ -2,4 +2,6 @@
|
||||
.env
|
||||
build/
|
||||
dist/
|
||||
.vscode
|
||||
.vscode
|
||||
.DS_Store
|
||||
.dev/
|
||||
|
||||
+22
-1
@@ -6,6 +6,27 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
formats:
|
||||
- tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
# use zip format for windows
|
||||
- goos: windows
|
||||
formats:
|
||||
- zip
|
||||
@@ -0,0 +1,52 @@
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
- Use `make test-ui` after making changes to the UI in ui-svelte/
|
||||
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
fixes #123
|
||||
```
|
||||
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
@@ -9,9 +9,6 @@ ifneq ($(shell git status --porcelain),)
|
||||
GIT_HASH := $(GIT_HASH)+
|
||||
endif
|
||||
|
||||
# Get the build number from the commit count on the main branch
|
||||
COMMIT_COUNT := $(shell git rev-list --count HEAD)
|
||||
|
||||
# Capture the current build date in RFC3339 format
|
||||
BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
@@ -22,27 +19,59 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/...
|
||||
staticcheck ./proxy/... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui-svelte && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui-svelte && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
mac: ui
|
||||
@echo "Building Mac binary..."
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=${COMMIT_COUNT} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
|
||||
# Build Linux binary
|
||||
linux:
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=${COMMIT_COUNT} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
linux: linux-arm64 linux-amd64
|
||||
|
||||
linux-amd64: ui
|
||||
@echo "Building Linux AMD64 binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
|
||||
linux-arm64: ui
|
||||
@echo "Building Linux ARM64 binary..."
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
simple-responder-windows:
|
||||
@echo "Building simple responder for windows"
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
@@ -55,9 +84,22 @@ release:
|
||||
echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." >&2; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "Creating release tag v$(COMMIT_COUNT)..."
|
||||
git tag v$(COMMIT_COUNT)
|
||||
git push origin v$(COMMIT_COUNT)
|
||||
|
||||
# Get the highest tag in v{number} format, increment it, and create a new tag
|
||||
@highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \
|
||||
new_tag="v$$(( $${highest_tag#v} + 1 ))"; \
|
||||
echo "tagging new version: $$new_tag"; \
|
||||
git tag "$$new_tag";
|
||||
|
||||
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
|
||||
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
||||
wol-proxy: $(BUILD_DIR)
|
||||
@echo "Building wol-proxy"
|
||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||
|
||||
test-ui:
|
||||
cd ui-svelte && npm ci && npm run check && npm test
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
.PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
|
||||
.PHONE: linux linux-arm64 linux-amd64
|
||||
|
||||
@@ -1,140 +1,263 @@
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||

|
||||
Run multiple generative AI models on your machine and hot-swap between them on demand. llama-swap works with any OpenAI and Anthropic API compatible server and is used by thousands of people to power their local AI workflows.
|
||||
|
||||
# Introduction
|
||||
llama-swap is an OpenAI API compatible server that gives you complete control over how you use your hardware. It automatically swaps to the configuration of your choice for serving a model. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||
|
||||
Features:
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Single yaml configuration file
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- ✅ Use any local server that provides an OpenAI compatible API (llama.cpp, vllm, tabblyAPI, etc)
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
- `/completion` - for completion endpoint
|
||||
- ✅ SDAPI via [stable-diffusion.cpp's server](https://github.com/leejet/stable-diffusion.cpp/tree/master/examples/server)
|
||||
- `/sdapi/v1/txt2img`
|
||||
- `/sdapi/v1/img2img`
|
||||
- `/sdapi/v1/loras` - requires `model` in request body to fetch the correct loras
|
||||
- ✅ llama-swap API
|
||||
- `/ui` - web UI
|
||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run concurrent models with a custom DSL swap matrix ([#643](https://github.com/mostlygeek/llama-swap/issues/643))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||
|
||||
## Releases
|
||||
### Web UI
|
||||
|
||||
Builds for Linux and OSX are available on the [Releases](https://github.com/mostlygeek/llama-swap/releases) page.
|
||||
llama-swap includes a real time web interface with a playground for testing out all sorts of local models:
|
||||
|
||||
### Building from source
|
||||
<img width="1125" height="876" alt="image" src="https://github.com/user-attachments/assets/8ee41947-97af-463d-b0f0-8e9c478fac07" />
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
View detailed token metrics:
|
||||
|
||||
## config.yaml
|
||||
<img width="1111" height="515" alt="image" src="https://github.com/user-attachments/assets/64bfb280-d7a3-4126-971a-a128fd40410c" />
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
Inspect request and responses:
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
<img width="1111" height="720" alt="image" src="https://github.com/user-attachments/assets/24fe4aca-1448-4d7c-b9e8-a967589bda6c" />
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
Manually load and unload models:
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
<img width="1109" height="719" alt="image" src="https://github.com/user-attachments/assets/02b1e1f2-abd0-4050-84ae-facd66ff01c4" />
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
Real time log streaming:
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl values must be a value greater than 0
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
```
|
||||
|
||||
More [examples](examples/README.md) are available for different use cases.
|
||||
<img width="1107" height="559" alt="image" src="https://github.com/user-attachments/assets/39669a10-cff2-409e-836a-5bad8bd0140c" />
|
||||
|
||||
## Installation
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
## Monitoring Logs
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. WinGet
|
||||
4. From release binaries
|
||||
5. From source
|
||||
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
Of course, CLI access is also supported:
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||
|
||||
```shell
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# run with a custom configuration and models directory
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# configuration hot reload supported with a
|
||||
# directory volume mount
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
-v /path/to/config:/config \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
more examples
|
||||
</summary>
|
||||
|
||||
```shell
|
||||
# pull latest images per platform
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cpu
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:intel
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
# non-root cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||
|
||||
```
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# streams logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
</details>
|
||||
|
||||
### Homebrew Install (macOS/Linux)
|
||||
|
||||
```shell
|
||||
brew tap mostlygeek/llama-swap
|
||||
brew install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### WinGet Install (Windows)
|
||||
|
||||
> [!NOTE]
|
||||
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
|
||||
|
||||
```shell
|
||||
# install
|
||||
C:\> winget install llama-swap
|
||||
|
||||
# upgrade
|
||||
C:\> winget upgrade llama-swap
|
||||
```
|
||||
|
||||
### Pre-built Binaries
|
||||
|
||||
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Building requires Go and Node.js (for UI).
|
||||
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. look in the `build/` subdirectory for the llama-swap binary
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
# minimum viable config.yaml
|
||||
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
|
||||
That's all you need to get started:
|
||||
|
||||
1. `models` - holds all model configurations
|
||||
2. `model1` - the ID used in API calls
|
||||
3. `cmd` - the command to run to start the server.
|
||||
4. `${PORT}` - an automatically assigned port number
|
||||
|
||||
Almost all configuration settings are optional and can be added one step at a time:
|
||||
|
||||
- Advanced features
|
||||
- `matrix` to run concurrent models with a custom swap logic DSL
|
||||
- `hooks` to run things on startup
|
||||
- `macros` reusable snippets
|
||||
- Model customization
|
||||
- `ttl` to automatically unload models
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
- `filters` rewrite parts of requests before sending to the upstream server
|
||||
|
||||
See the [configuration documentation](docs/configuration.md) for all options.
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, using a `matrix` allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Reverse Proxy Configuration (nginx)
|
||||
|
||||
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks Server‑Sent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
|
||||
|
||||
Recommended nginx configuration snippets:
|
||||
|
||||
```nginx
|
||||
# SSE for UI events/logs
|
||||
location /api/events {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Streaming chat completions (stream=true)
|
||||
location /v1/chat/completions {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
```
|
||||
|
||||
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```sh
|
||||
# sends up to the last 10KB of logs
|
||||
$ curl http://host/logs
|
||||
|
||||
# streams combined logs
|
||||
curl -Ns http://host/logs/stream
|
||||
|
||||
# stream llama-swap's proxy status logs
|
||||
curl -Ns http://host/logs/stream/proxy
|
||||
|
||||
# stream logs from upstream processes that llama-swap loads
|
||||
curl -Ns http://host/logs/stream/upstream
|
||||
|
||||
# stream logs only from a specific model
|
||||
curl -Ns http://host/logs/stream/{model_id}
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
# appending ?no-history will disable sending buffered history first
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
|
||||
|
||||
[Service]
|
||||
User=nobody
|
||||
## Star History
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
||||
> [!NOTE]
|
||||
> ⭐️ Star this project to help others discover it!
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||
|
||||
## Current Issues
|
||||
|
||||
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||
4. Linked list structure has poor cache locality and pointer overhead
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### New CircularBuffer Type
|
||||
|
||||
Create a simple circular byte buffer with:
|
||||
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||
- `head` and `size` integers to track write position and data length
|
||||
- No per-write allocations
|
||||
|
||||
### API Requirements
|
||||
|
||||
The new buffer must support:
|
||||
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||
|
||||
### Implementation Details
|
||||
|
||||
```go
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
```
|
||||
|
||||
**Write logic:**
|
||||
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||
- Update `head` and `size` accordingly
|
||||
- Data is copied into the internal buffer (not stored by reference)
|
||||
|
||||
**GetHistory logic:**
|
||||
- Calculate start position: `(head - size + cap) % cap`
|
||||
- If not wrapped: single slice copy
|
||||
- If wrapped: two copies (end of buffer + beginning)
|
||||
- Returns a **new slice** (copy), not a view into internal buffer
|
||||
|
||||
### Immutability Guarantees (must preserve)
|
||||
|
||||
Per existing tests:
|
||||
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||
|
||||
## Files to Modify
|
||||
|
||||
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||
|
||||
## Testing Plan
|
||||
|
||||
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||
|
||||
Add new tests:
|
||||
- Test buffer wrap-around behavior
|
||||
- Test large writes that exceed buffer capacity
|
||||
- Test exact capacity boundary conditions
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||
- [ ] Implement `Write()` method for circular buffer
|
||||
- [ ] Implement `GetHistory()` method for circular buffer
|
||||
- [ ] Update `LogMonitor` struct to use new buffer
|
||||
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||
- [ ] Remove `"container/ring"` import
|
||||
- [ ] Run `make test-dev` to verify existing tests pass
|
||||
- [ ] Add wrap-around test case
|
||||
- [ ] Run `make test-all` for final validation
|
||||
@@ -0,0 +1,183 @@
|
||||
# Improve Testability (#655)
|
||||
|
||||
## Current Pain Points
|
||||
|
||||
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
|
||||
|
||||
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
|
||||
|
||||
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
|
||||
|
||||
## Stages
|
||||
|
||||
### Stage 1: YAML-based test config helper
|
||||
|
||||
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
|
||||
|
||||
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
|
||||
|
||||
Create a test helper in `proxy/helpers_test.go`:
|
||||
|
||||
```go
|
||||
// testConfigFromYAML substitutes simple-responder paths and loads through
|
||||
// the real config pipeline (env vars, macros, port assignment, etc.)
|
||||
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||
t.Helper()
|
||||
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||
require.NoError(t, err)
|
||||
return cfg
|
||||
}
|
||||
```
|
||||
|
||||
Tests would then look like:
|
||||
|
||||
```go
|
||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||
config := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
models:
|
||||
model1:
|
||||
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||
model2:
|
||||
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
|
||||
`)
|
||||
proxy := New(config)
|
||||
// ... same assertions
|
||||
}
|
||||
```
|
||||
|
||||
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
|
||||
|
||||
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
|
||||
|
||||
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
|
||||
|
||||
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
|
||||
|
||||
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
|
||||
|
||||
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
|
||||
|
||||
**2a. Add testHandler to Process:**
|
||||
|
||||
```go
|
||||
// In Process struct (process.go):
|
||||
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
|
||||
```
|
||||
|
||||
In `Process.Start()`, skip subprocess + health check when handler is set:
|
||||
|
||||
```go
|
||||
func (p *Process) start() error {
|
||||
if p.testHandler != nil {
|
||||
p.setState(StateReady)
|
||||
return nil
|
||||
}
|
||||
// existing subprocess logic...
|
||||
}
|
||||
```
|
||||
|
||||
In `Process.ProxyRequest()`, delegate directly to the handler:
|
||||
|
||||
```go
|
||||
// Before the reverseProxy.ServeHTTP call:
|
||||
if p.testHandler != nil {
|
||||
p.testHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
**2b. Test helper to create the handler:**
|
||||
|
||||
```go
|
||||
// newTestHandler returns an http.Handler that mimics llama.cpp's API
|
||||
// (same endpoints as simple-responder).
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||
// ... other endpoints
|
||||
return mux
|
||||
}
|
||||
```
|
||||
|
||||
Tests for routing/auth/CORS/streaming then become:
|
||||
|
||||
```go
|
||||
func TestProxyManager_AuthRequired(t *testing.T) {
|
||||
handler := newTestHandler("model1")
|
||||
|
||||
config := testConfigFromYAML(t, `
|
||||
healthCheckTimeout: 15
|
||||
logLevel: error
|
||||
requiredAPIKeys: [test-key]
|
||||
models:
|
||||
model1:
|
||||
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||
`)
|
||||
pm := NewProxyManager(config)
|
||||
// inject handler — skips subprocess, health check, port allocation
|
||||
pm.processGroups["model1"].process.testHandler = handler
|
||||
}
|
||||
```
|
||||
|
||||
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
|
||||
|
||||
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
|
||||
|
||||
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
|
||||
|
||||
### Stage 3: Migrate tests incrementally
|
||||
|
||||
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
|
||||
|
||||
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
|
||||
|
||||
Priority order:
|
||||
1. `proxymanager_test.go` routing tests (highest count, most repetition)
|
||||
2. `peerproxy_test.go` (straightforward, all HTTP routing)
|
||||
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
|
||||
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
|
||||
|
||||
Tests that **must keep simple-responder:**
|
||||
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
|
||||
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
|
||||
|
||||
**Scope:** ~60-70% of tests can drop simple-responder.
|
||||
|
||||
### Stage 4 (optional): Process interface for ProcessGroup
|
||||
|
||||
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
|
||||
|
||||
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
|
||||
|
||||
```go
|
||||
type ProcessController interface {
|
||||
Start() error
|
||||
Stop(StopStrategy)
|
||||
ProxyRequest(http.ResponseWriter, *http.Request) error
|
||||
CurrentState() ProcessState
|
||||
ID() string
|
||||
SetState(ProcessState) // for test setup
|
||||
}
|
||||
```
|
||||
|
||||
This requires:
|
||||
- Extracting the interface
|
||||
- A `MockProcess` implementation
|
||||
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
|
||||
|
||||
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
|
||||
|
||||
## Effort/Impact Summary
|
||||
|
||||
| Stage | Effort | Impact | Risk |
|
||||
|-------|--------|--------|------|
|
||||
| 1. YAML config helper | Low | Config bugs caught earlier | None |
|
||||
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
|
||||
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
|
||||
| 4. Process interface | High | Pure unit tests possible | Medium |
|
||||
|
||||
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
|
||||
@@ -0,0 +1,292 @@
|
||||
# Add Model Metadata Support with Typed Macros
|
||||
|
||||
## Overview
|
||||
|
||||
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||
|
||||
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. Enhanced Macro System
|
||||
|
||||
**Current State:**
|
||||
|
||||
- Macros are defined as `map[string]string` at both global and model levels
|
||||
- Only string substitution is supported
|
||||
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Implement type-preserving macro substitution:
|
||||
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||
- Add validation to ensure macro values are scalar types only
|
||||
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||
|
||||
**Implementation Details:**
|
||||
|
||||
- Create a generic helper function to perform macro substitution that:
|
||||
- Takes a value of type `any`
|
||||
- Recursively processes maps, slices, and scalar values
|
||||
- Replaces `${macro_name}` patterns with macro values
|
||||
- Preserves types for direct substitution
|
||||
- Converts to strings for interpolated substitution
|
||||
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||
- Maintain backward compatibility with existing string-only macros
|
||||
|
||||
### 2. Metadata Field in ModelConfig
|
||||
|
||||
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||
- Apply macro substitution to metadata values during config loading
|
||||
|
||||
**Schema Requirements:**
|
||||
|
||||
- Metadata is optional (default: empty/nil map)
|
||||
- Supports nested structures (objects within objects, arrays, etc.)
|
||||
- All string values within metadata undergo macro substitution
|
||||
- Type preservation rules apply as described above
|
||||
|
||||
### 3. Macro Substitution in Metadata
|
||||
|
||||
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||
|
||||
**Process Flow:**
|
||||
|
||||
1. After loading YAML configuration
|
||||
2. After model-level and global macro merging
|
||||
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||
5. Process recursively through all nested structures
|
||||
|
||||
**Substitution Rules:**
|
||||
|
||||
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||
- `temperature: ${temp}` → keeps float type from temp macro
|
||||
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||
- Arrays and nested objects are processed recursively
|
||||
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||
|
||||
### 4. API Response Updates
|
||||
|
||||
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||
|
||||
**Current Behavior:**
|
||||
|
||||
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||
- Optionally includes: `name`, `description`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add metadata to each model record under the key `llamaswap_meta`
|
||||
- Only include `llamaswap_meta` if metadata is non-empty
|
||||
- Preserve all types when marshaling to JSON
|
||||
- Maintain existing sorting by model ID
|
||||
|
||||
**Example Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "llama",
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "llama-swap",
|
||||
"name": "llama 3.1 8B",
|
||||
"description": "A small but capable model",
|
||||
"llamaswap_meta": {
|
||||
"port": 10001,
|
||||
"temperature": 0.7,
|
||||
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||
"an_obj": {
|
||||
"a": "1",
|
||||
"b": 2,
|
||||
"c": [0.7, false, "model: llama"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Validation and Error Handling
|
||||
|
||||
**Macro Validation:**
|
||||
|
||||
- Extend `validateMacro()` to accept values of type `any`
|
||||
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Reject complex types (maps, slices, structs) as macro values
|
||||
- Maintain existing validation for macro names and lengths
|
||||
|
||||
**Configuration Loading:**
|
||||
|
||||
- Fail fast if unknown macros are found in metadata
|
||||
- Provide clear error messages indicating which model and field contains errors
|
||||
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### Test 1: Model-Level Macros with Different Types
|
||||
|
||||
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define model with macros of each scalar type
|
||||
- Verify metadata correctly substitutes and preserves types
|
||||
- Test direct substitution (`port: ${PORT}`)
|
||||
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||
- Verify nested objects and arrays work correctly
|
||||
|
||||
### Test 2: Global and Model Macro Precedence
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define same macro at global and model level with different types
|
||||
- Verify model-level macro takes precedence
|
||||
- Test metadata uses correct macro value
|
||||
- Verify type is preserved from the winning macro
|
||||
|
||||
### Test 3: Macro Validation
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Test that complex types (maps, arrays) are rejected as macro values
|
||||
- Verify error message includes: macro name and type that was rejected
|
||||
- Test that scalar types (string, int, float, bool) are accepted
|
||||
- Each type should load without error
|
||||
- Test macro name validation still works with `any` types
|
||||
- Invalid characters, reserved names, length limits should still be enforced
|
||||
|
||||
### Test 4: Metadata in API Response
|
||||
|
||||
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
|
||||
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Model with metadata → verify `llamaswap_meta` key appears
|
||||
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||
- Verify all types are correctly marshaled to JSON
|
||||
- Verify nested structures are preserved
|
||||
- Verify macro substitution has occurred before serialization
|
||||
|
||||
### Test 5: Unknown Macros in Metadata
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Use undefined macro in metadata
|
||||
- Verify configuration loading fails with clear error
|
||||
- Error should indicate model name and that macro is undefined
|
||||
|
||||
### Test 6: Recursive Substitution
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Metadata with deeply nested structures
|
||||
- Arrays containing objects with macros
|
||||
- Objects containing arrays with macros
|
||||
- Mixed string interpolation and direct substitution at various nesting levels
|
||||
|
||||
## Checklist
|
||||
|
||||
### Configuration Schema Changes
|
||||
|
||||
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||
- [x] Add validation logic to ensure macro values are scalar types only
|
||||
|
||||
### Macro Substitution Logic
|
||||
|
||||
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||
- [x] Implement type-preserving direct substitution logic
|
||||
- [x] Implement string interpolation with type conversion
|
||||
- [x] Handle maps: recursively process all values
|
||||
- [x] Handle slices: recursively process all elements
|
||||
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||
|
||||
### API Response Changes
|
||||
|
||||
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||
- [x] Verify JSON marshaling preserves all types correctly
|
||||
|
||||
### Testing - Config Package
|
||||
|
||||
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
### Testing - Model Config Package
|
||||
|
||||
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||
- [x] Test metadata with various scalar types
|
||||
- [x] Test metadata with nested objects and arrays
|
||||
|
||||
### Testing - Proxy Manager
|
||||
|
||||
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
- [x] Add test case for model with metadata
|
||||
- [x] Add test case for model without metadata
|
||||
- [x] Verify `llamaswap_meta` key presence/absence
|
||||
- [x] Verify type preservation in JSON output
|
||||
- [x] Verify macro substitution has occurred
|
||||
|
||||
### Documentation
|
||||
|
||||
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||
- [x] No additional documentation needed per project instructions
|
||||
|
||||
## Known Issues and Considerations
|
||||
|
||||
### Inconsistencies
|
||||
|
||||
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||
|
||||
### Design Decisions
|
||||
|
||||
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||
|
||||
- Avoids potential collisions with OpenAI API standard fields
|
||||
- Makes it clear this is llama-swap specific metadata
|
||||
- Easier for clients to distinguish standard vs. custom fields
|
||||
|
||||
2. **Why support nested structures?**
|
||||
|
||||
- Provides maximum flexibility for users
|
||||
- Aligns with the schemaless design principle
|
||||
- Example config already demonstrates this capability
|
||||
|
||||
3. **Why validate macro types?**
|
||||
- Prevents confusing behavior (e.g., substituting a map)
|
||||
- Makes configuration errors explicit at load time
|
||||
- Simpler implementation and testing
|
||||
@@ -0,0 +1,397 @@
|
||||
# Improve macro-in-macro support
|
||||
|
||||
**Status: COMPLETED ✅**
|
||||
|
||||
## Title
|
||||
|
||||
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||
|
||||
## Overview
|
||||
|
||||
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||
|
||||
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||
|
||||
**Outcomes:**
|
||||
- Macros can reference other macros defined earlier in the config
|
||||
- Macro substitution is deterministic and order-dependent
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||
- All existing tests pass
|
||||
- New tests validate substitution order and self-reference detection
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. YAML Parsing Strategy
|
||||
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||
|
||||
### 2. Data Structure Changes
|
||||
|
||||
#### Current Implementation (config.go:19)
|
||||
```go
|
||||
type MacroList map[string]any
|
||||
```
|
||||
|
||||
#### New Implementation
|
||||
```go
|
||||
type MacroList []MacroEntry
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
```
|
||||
|
||||
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||
|
||||
### 3. Macro Substitution Order Rules
|
||||
|
||||
The substitution must follow this hierarchy (from most specific to least):
|
||||
|
||||
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||
3. **Global macros** (first): Defined at config root level
|
||||
|
||||
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||
- The last macro defined is substituted first
|
||||
- This allows later macros to reference earlier ones
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
|
||||
### 4. Macro Reference Rules
|
||||
|
||||
**Allowed:**
|
||||
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||
- Model macros can reference global macros
|
||||
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||
|
||||
**Prohibited:**
|
||||
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||
- Macro cannot reference macros defined **after** it
|
||||
- No circular references (prevented by single-pass, ordered substitution)
|
||||
|
||||
### 5. Validation Requirements
|
||||
|
||||
Add validation to detect:
|
||||
- **Self-references:** Macro value contains reference to its own name
|
||||
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||
|
||||
Error messages should be clear:
|
||||
```
|
||||
macro 'foo' contains self-reference
|
||||
unknown macro '${bar}' in model.cmd
|
||||
```
|
||||
|
||||
### 6. Implementation Changes
|
||||
|
||||
#### Files to Modify
|
||||
|
||||
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||
- Line 19: Change `MacroList` type definition
|
||||
- Line 69: Update `Macros MacroList` field
|
||||
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||
- Line 175-188: Update model-level macro validation
|
||||
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||
- Line 389-415: Update `validateMacro` to detect self-references
|
||||
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||
|
||||
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||
- Line 33: Update `Macros MacroList` field type
|
||||
|
||||
3. **All test files**
|
||||
- Update test fixtures to use ordered macro definitions
|
||||
- Ensure tests specify macro order explicitly
|
||||
|
||||
#### Core Algorithm
|
||||
|
||||
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||
|
||||
```go
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||
|
||||
// Add global macros first
|
||||
for _, entry := range config.Macros {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add reserved MODEL_ID macro at the end
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Check if PORT macro is needed
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro to the end (highest priority)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add this new helper function to replace `substituteMetadataMacros`:
|
||||
|
||||
```go
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Self-Reference Detection
|
||||
|
||||
Add to `validateMacro` function:
|
||||
|
||||
```go
|
||||
func validateMacro(name string, value any) error {
|
||||
// ... existing validation ...
|
||||
|
||||
// Check for self-reference
|
||||
if str, ok := value.(string); ok {
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(str, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### 1. Migration Tests
|
||||
- **Test:** All existing macro tests still pass after YAML library migration
|
||||
- **Files:** All `*_test.go` files with macro tests
|
||||
|
||||
### 2. Macro Order Tests
|
||||
|
||||
#### Test: Macro-in-macro substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "echo ${B}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||
|
||||
#### Test: LIFO substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "load ${full}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||
|
||||
#### Test: Model macro overrides global
|
||||
```yaml
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: "echo ${msg}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||
|
||||
### 3. Reserved Macro Tests
|
||||
|
||||
#### Test: MODEL_ID substituted in macro
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: "${podman-llama} -m model.gguf"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||
|
||||
### 4. Error Detection Tests
|
||||
|
||||
#### Test: Self-reference detection
|
||||
```yaml
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
```
|
||||
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||
|
||||
#### Test: Undefined macro reference
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
```
|
||||
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||
|
||||
### 5. Regression Tests
|
||||
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||
- Ensure all pass without modification (except test fixtures if needed)
|
||||
|
||||
## Checklist
|
||||
|
||||
### Phase 1: Data Structure Changes
|
||||
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||
- [ ] Implement helper functions:
|
||||
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||
|
||||
### Phase 2: Macro Validation Updates
|
||||
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||
- [ ] Test self-reference detection with new test case
|
||||
|
||||
### Phase 3: Macro Substitution Algorithm
|
||||
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||
- [ ] Substitute in metadata within same loop
|
||||
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||
|
||||
### Phase 4: Testing
|
||||
- [ ] Run `make test-dev` - fix any static checking errors
|
||||
- [ ] Add test: macro-in-macro basic substitution
|
||||
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||
- [ ] Add test: MODEL_ID in global macro used by model
|
||||
- [ ] Add test: PORT in global macro used by model
|
||||
- [ ] Add test: model macro overrides global macro in substitution
|
||||
- [ ] Add test: self-reference detection error
|
||||
- [ ] Add test: undefined macro reference error
|
||||
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||
|
||||
### Phase 5: Documentation
|
||||
- [ ] Update plan status in this file (mark completed)
|
||||
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||
- [ ] Verify no new error messages need user documentation
|
||||
|
||||
## Bug Example (Original Issue)
|
||||
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": >
|
||||
podman run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||
|
||||
"standard-options": >
|
||||
--no-mmap --jinja
|
||||
|
||||
"kv8": >
|
||||
-fa on -ctk q8_0 -ctv q8_0
|
||||
```
|
||||
|
||||
**Current Bug:**
|
||||
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||
|
||||
**After Fix:**
|
||||
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||
// to make sure all the requests are accounted for.
|
||||
//
|
||||
// requests can be sent in parallel, and the tool will report the results.
|
||||
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// ----- CLI arguments ----------------------------------------------------
|
||||
var (
|
||||
baseurl string
|
||||
modelName string
|
||||
totalRequests int
|
||||
parallelization int
|
||||
)
|
||||
|
||||
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||
flag.Parse()
|
||||
|
||||
if baseurl == "" || modelName == "" {
|
||||
fmt.Println("Error: both -baseurl and -model are required.")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if totalRequests <= 0 {
|
||||
fmt.Println("Error: -requests must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if parallelization <= 0 {
|
||||
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// ----- HTTP client -------------------------------------------------------
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// ----- Tracking response codes -------------------------------------------
|
||||
statusCounts := make(map[int]int) // map[statusCode]count
|
||||
var mu sync.Mutex // protects statusCounts
|
||||
|
||||
// ----- Request queue (buffered channel) ----------------------------------
|
||||
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||
|
||||
// Goroutine to fill the request queue
|
||||
go func() {
|
||||
for i := 0; i < totalRequests; i++ {
|
||||
requests <- i + 1
|
||||
}
|
||||
close(requests)
|
||||
}()
|
||||
|
||||
// ----- Worker pool -------------------------------------------------------
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < parallelization; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for reqID := range requests {
|
||||
// Build request payload as a single line JSON string
|
||||
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||
|
||||
// Send POST request
|
||||
req, err := http.NewRequest(http.MethodPost,
|
||||
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||
bytes.NewReader([]byte(payload)))
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
// Record status code
|
||||
mu.Lock()
|
||||
statusCounts[resp.StatusCode]++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
// ----- Status ticker (prints every second) -------------------------------
|
||||
done := make(chan struct{})
|
||||
tickerDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
startTime := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mu.Lock()
|
||||
// Compute how many requests have completed so far
|
||||
completed := 0
|
||||
for _, cnt := range statusCounts {
|
||||
completed += cnt
|
||||
}
|
||||
// Calculate duration and progress
|
||||
duration := time.Since(startTime)
|
||||
progress := completed * 100 / totalRequests
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||
mu.Unlock()
|
||||
case <-done:
|
||||
duration := time.Since(startTime)
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||
close(tickerDone)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all workers to finish
|
||||
wg.Wait()
|
||||
close(done) // stops the status-update goroutine
|
||||
<-tickerDone // give ticker time to finish / print
|
||||
|
||||
// ----- Summary ------------------------------------------------------------
|
||||
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||
mu.Lock()
|
||||
for code, cnt := range statusCounts {
|
||||
if code == -1 {
|
||||
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||
} else {
|
||||
fmt.Printf("%d : %d\n", code, cnt)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
**
|
||||
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||
|
||||
- process is killed externally, what happens with cmd.Wait() *
|
||||
✔︎ it returns. catches crashes.*
|
||||
|
||||
- process ignores SIGTERM*
|
||||
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||
|
||||
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||
x they stick around. have to be manually killed.*
|
||||
|
||||
- .WithTimeout()'s cancel is called *
|
||||
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||
|
||||
- parent receives SIGINT/SIGTERM, what happens
|
||||
✔︎ waits for child process to exit, then exits gracefully.
|
||||
*/
|
||||
func main() {
|
||||
|
||||
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"../../build/simple-responder_darwin_arm64",
|
||||
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// set a wait delay before signing sig kill
|
||||
cmd.WaitDelay = 500 * time.Millisecond
|
||||
cmd.Cancel = func() error {
|
||||
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
//return nil
|
||||
|
||||
// this error is returned by cmd.Wait(), and can be used to
|
||||
// single an error when the process couldn't be normally terminated
|
||||
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||
// as WaitDelay timing out will override the any error set here.
|
||||
//
|
||||
// test by enabling/disabling -ignore-sig-term on the process
|
||||
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||
// without it, it will show the "new error from cancel"
|
||||
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Error starting process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||
// this program to eventually exit gracefully.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
signal := <-sigChan
|
||||
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||
} else {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||
}
|
||||
fmt.Println("✔︎ Child process exited, Done.")
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510
|
||||
|
||||
To run it:
|
||||
> curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq .
|
||||
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"model": "bge-reranker",
|
||||
"query": "Organic skincare products for sensitive skin",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.",
|
||||
"New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.",
|
||||
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.",
|
||||
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern – lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.",
|
||||
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.",
|
||||
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.",
|
||||
"针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。",
|
||||
"新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。",
|
||||
"敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。",
|
||||
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。"
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// Check if streaming is requested
|
||||
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||
isStreaming := c.Query("stream") == "true"
|
||||
|
||||
if isStreaming {
|
||||
// Set headers for streaming
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
// Send 10 "asdf" tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
data := gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []gin.H{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": gin.H{
|
||||
"content": "asdf",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", data)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Send final data with usage info
|
||||
finalData := gin.H{
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
// add timings to simulate llama.cpp
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", finalData)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Send [DONE]
|
||||
c.SSEvent("message", "[DONE]")
|
||||
c.Writer.Flush()
|
||||
} else {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// llama-server compatibility: /completion
|
||||
r.POST("/completion", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the model from the form values
|
||||
model := c.Request.FormValue("model")
|
||||
|
||||
if model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file from the form
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file content to get its size
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := len(fileBytes)
|
||||
|
||||
// Return a JSON response with the model and transcription text including file size
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
|
||||
// expose some header values for testing
|
||||
"h_content_type": c.GetHeader("Content-Type"),
|
||||
"h_content_length": c.GetHeader("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||
model := c.Query("model")
|
||||
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
// SD API endpoints
|
||||
r.POST("/sdapi/v1/txt2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.POST("/sdapi/v1/img2img", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": modelName,
|
||||
"images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/sdapi/v1/loras", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
if !*silent {
|
||||
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
countSigInt := 0
|
||||
|
||||
runloop:
|
||||
for {
|
||||
signal := <-sigChan
|
||||
switch signal {
|
||||
case syscall.SIGINT:
|
||||
countSigInt++
|
||||
if countSigInt > 1 {
|
||||
break runloop
|
||||
} else {
|
||||
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||
}
|
||||
case syscall.SIGTERM:
|
||||
if *ignoreSigTerm {
|
||||
log.Println("Ignoring SIGTERM")
|
||||
} else {
|
||||
log.Println("Received SIGTERM, shutting down")
|
||||
break runloop
|
||||
}
|
||||
default:
|
||||
break runloop
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
# wol-proxy
|
||||
|
||||
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
|
||||
|
||||
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
|
||||
|
||||
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
# minimal
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
|
||||
|
||||
# everything
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
|
||||
# use debug log level
|
||||
-log debug \
|
||||
# altenerative listening port
|
||||
-listen localhost:9999 \
|
||||
# seconds to hold requests waiting for upstream to be ready
|
||||
-timeout 30
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
`GET /status` - that's it. Everything else is proxied to the upstream server.
|
||||
@@ -0,0 +1,64 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Loading...</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.loader {
|
||||
text-align: center;
|
||||
}
|
||||
.stats {
|
||||
font-size: 18px;
|
||||
color: #333;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.stats-label {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="loader">
|
||||
<p>Waking up upstream server...</p>
|
||||
<div class="stats">
|
||||
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
|
||||
<div><span id="attempts"> </span></div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
var startTime = Date.now();
|
||||
var attempts = 0;
|
||||
|
||||
setInterval(function() {
|
||||
var elapsed = (Date.now() - startTime) / 1000;
|
||||
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
|
||||
}, 100);
|
||||
|
||||
// Check status every second
|
||||
setInterval(function() {
|
||||
attempts++;
|
||||
var dots = '.'.repeat((attempts % 10) || 10);
|
||||
document.getElementById('attempts').textContent = dots;
|
||||
|
||||
fetch('/status')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(t) {
|
||||
if (t.indexOf('status: ready') !== -1) {
|
||||
location.reload();
|
||||
}
|
||||
})
|
||||
.catch(function() {});
|
||||
}, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,333 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed index.html
|
||||
var loadingPageHTML string
|
||||
|
||||
var (
|
||||
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
|
||||
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
|
||||
flagListen = flag.String("listen", ":8080", "listen address to listen on")
|
||||
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
|
||||
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
switch *flagLog {
|
||||
case "debug":
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
case "info":
|
||||
slog.SetLogLoggerLevel(slog.LevelInfo)
|
||||
case "warn":
|
||||
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||
case "error":
|
||||
slog.SetLogLoggerLevel(slog.LevelError)
|
||||
default:
|
||||
slog.Error("invalid log level", "logLevel", *flagLog)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate flags
|
||||
if *flagListen == "" {
|
||||
slog.Error("listen address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagMac == "" {
|
||||
slog.Error("mac address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagTimeout < 1 {
|
||||
slog.Error("timeout must be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
var upstreamURL *url.URL
|
||||
var err error
|
||||
// validate mac address
|
||||
if _, err = net.ParseMAC(*flagMac); err != nil {
|
||||
slog.Error("invalid mac address", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *flagUpstream == "" {
|
||||
slog.Error("upstream proxy address is required")
|
||||
return
|
||||
} else {
|
||||
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
|
||||
if err != nil {
|
||||
slog.Error("error parsing upstream url", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxy := newProxy(upstreamURL)
|
||||
server := &http.Server{
|
||||
Addr: *flagListen,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// start the server
|
||||
go func() {
|
||||
slog.Info("server starting on", "address", *flagListen)
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
slog.Error("error starting server", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// graceful shutdown
|
||||
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
<-ctx.Done()
|
||||
server.Close()
|
||||
}
|
||||
|
||||
type upstreamStatus string
|
||||
|
||||
const (
|
||||
notready upstreamStatus = "not ready"
|
||||
ready upstreamStatus = "ready"
|
||||
)
|
||||
|
||||
type proxyServer struct {
|
||||
upstreamProxy *httputil.ReverseProxy
|
||||
failCount int
|
||||
statusMutex sync.RWMutex
|
||||
status upstreamStatus
|
||||
}
|
||||
|
||||
func newProxy(url *url.URL) *proxyServer {
|
||||
p := httputil.NewSingleHostReverseProxy(url)
|
||||
proxy := &proxyServer{
|
||||
upstreamProxy: p,
|
||||
status: notready,
|
||||
failCount: 0,
|
||||
}
|
||||
|
||||
// start a goroutine to monitor upstream status via SSE
|
||||
go func() {
|
||||
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
|
||||
client := &http.Client{
|
||||
Timeout: 0, // No timeout for SSE connection
|
||||
}
|
||||
|
||||
waitDuration := 10 * time.Second
|
||||
|
||||
for {
|
||||
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
|
||||
|
||||
req, err := http.NewRequest("GET", eventsUrl, nil)
|
||||
if err != nil {
|
||||
slog.Warn("failed to create SSE request", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(waitDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Error("failed to connect to SSE endpoint", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully connected to SSE endpoint
|
||||
slog.Info("connected to SSE endpoint, upstream ready")
|
||||
proxy.setStatus(ready)
|
||||
proxy.resetFailures()
|
||||
|
||||
// Read from the SSE stream to detect disconnection
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
// use a fairly large buffer to avoid scanner errors when reading large SSE events
|
||||
buf := make([]byte, 0, 1024*1024*2)
|
||||
scanner.Buffer(buf, 1024*1024*2)
|
||||
events := 0
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
fmt.Print("Events: ")
|
||||
}
|
||||
for scanner.Scan() {
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
// Just read the events to keep connection alive
|
||||
// We don't need to process the event data
|
||||
events++
|
||||
fmt.Printf("%d, ", events)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if err := scanner.Err(); err != nil {
|
||||
slog.Error("error reading from SSE stream", "error", err)
|
||||
}
|
||||
|
||||
// Connection closed or error occurred
|
||||
_ = resp.Body.Close()
|
||||
slog.Info("SSE connection closed, upstream not ready")
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
|
||||
// Wait before reconnecting
|
||||
time.Sleep(waitDuration)
|
||||
}
|
||||
}()
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && r.URL.Path == "/status" {
|
||||
status := string(p.getStatus())
|
||||
failCount := p.getFailures()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprintf(w, "status: %s\n", status)
|
||||
fmt.Fprintf(w, "failures: %d\n", failCount)
|
||||
return
|
||||
}
|
||||
|
||||
if p.getStatus() == notready {
|
||||
path := r.URL.Path
|
||||
if strings.HasPrefix(path, "/api/events") {
|
||||
slog.Debug("Skipping wake up", "req", path)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
|
||||
if err := sendMagicPacket(*flagMac); err != nil {
|
||||
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||
}
|
||||
|
||||
// For root or UI path requests, return loading page with status polling
|
||||
// the web page will do the polling and redirect when ready
|
||||
if path == "/" || strings.HasPrefix(path, "/ui/") {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, loadingPageHTML)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
|
||||
defer cancel()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
slog.Info("timeout waiting for upstream to be ready")
|
||||
http.Error(w, "timeout", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.getStatus() == ready {
|
||||
ticker.Stop()
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.upstreamProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *proxyServer) getStatus() upstreamStatus {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.status
|
||||
}
|
||||
|
||||
func (p *proxyServer) setStatus(status upstreamStatus) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.status = status
|
||||
}
|
||||
|
||||
func (p *proxyServer) incFail(num int) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount += num
|
||||
}
|
||||
|
||||
func (p *proxyServer) getFailures() int {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.failCount
|
||||
}
|
||||
|
||||
func (p *proxyServer) resetFailures() {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount = 0
|
||||
}
|
||||
|
||||
func sendMagicPacket(macAddr string) error {
|
||||
hwAddr, err := net.ParseMAC(macAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hwAddr) != 6 {
|
||||
return errors.New("invalid MAC address")
|
||||
}
|
||||
|
||||
// Create the magic packet.
|
||||
packet := make([]byte, 102)
|
||||
// Add 6 bytes of 0xFF.
|
||||
for i := 0; i < 6; i++ {
|
||||
packet[i] = 0xFF
|
||||
}
|
||||
// Repeat the MAC address 16 times.
|
||||
for i := 1; i <= 16; i++ {
|
||||
copy(packet[i*6:], hwAddr)
|
||||
}
|
||||
|
||||
// Send the packet using UDP.
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: 9,
|
||||
}
|
||||
conn, err := net.DialUDP("udp", nil, &addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,520 @@
|
||||
{
|
||||
"$schema": "https://json-schema.org/draft-07/schema#",
|
||||
"$id": "llama-swap-config-schema.json",
|
||||
"title": "llama-swap configuration",
|
||||
"description": "Configuration file for llama-swap",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"models"
|
||||
],
|
||||
"definitions": {
|
||||
"macros": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 0,
|
||||
"maxLength": 1024
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"propertyNames": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[a-zA-Z0-9_-]+$",
|
||||
"not": {
|
||||
"enum": [
|
||||
"PORT",
|
||||
"MODEL_ID"
|
||||
]
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"keepalive": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP keepalive timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Time to wait for response headers in seconds. Set to 0 to disable."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"expectContinue": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 1,
|
||||
"description": "Expect-Continue timeout in seconds. Set to 0 to disable."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds. Set to 0 to disable."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections."
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"healthCheckTimeout": {
|
||||
"type": "integer",
|
||||
"minimum": 15,
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"globalTTL": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"debug",
|
||||
"info",
|
||||
"warn",
|
||||
"error"
|
||||
],
|
||||
"default": "info",
|
||||
"description": "Sets the logging value. Valid values: debug, info, warn, error."
|
||||
},
|
||||
"logTimeFormat": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"",
|
||||
"ansic",
|
||||
"unixdate",
|
||||
"rubydate",
|
||||
"rfc822",
|
||||
"rfc822z",
|
||||
"rfc850",
|
||||
"rfc1123",
|
||||
"rfc1123z",
|
||||
"rfc3339",
|
||||
"rfc3339nano",
|
||||
"kitchen",
|
||||
"stamp",
|
||||
"stampmilli",
|
||||
"stampmicro",
|
||||
"stampnano"
|
||||
],
|
||||
"default": "",
|
||||
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
|
||||
},
|
||||
"metricsMaxInMemory": {
|
||||
"type": "integer",
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"captureBuffer": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
|
||||
},
|
||||
"includeAliasesInList": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
|
||||
},
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"models": {
|
||||
"type": "object",
|
||||
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"cmd"
|
||||
],
|
||||
"properties": {
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"cmd": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
|
||||
},
|
||||
"cmdStop": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 128,
|
||||
"description": "Display name for the model. Used in v1/models API response."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 1024,
|
||||
"description": "Description for the model. Used in v1/models API response."
|
||||
},
|
||||
"env": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
|
||||
},
|
||||
"default": [],
|
||||
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
|
||||
},
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"default": "http://localhost:${PORT}",
|
||||
"format": "uri",
|
||||
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
|
||||
},
|
||||
"aliases": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Alternative model names for this configuration. Must be unique globally."
|
||||
},
|
||||
"checkEndpoint": {
|
||||
"type": "string",
|
||||
"default": "/health",
|
||||
"pattern": "^/.*$|^none$",
|
||||
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": -1,
|
||||
"default": -1,
|
||||
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
},
|
||||
"setParamsByID": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"default": {},
|
||||
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
|
||||
},
|
||||
"concurrencyLimit": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
|
||||
},
|
||||
"unlisted": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||
},
|
||||
"timeouts": {
|
||||
"$ref": "#/definitions/timeouts"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"matrix": {
|
||||
"type": "object",
|
||||
"description": "Solver-based alternative to groups. Declares valid combinations of concurrent models. The solver minimizes eviction cost when swapping. A config must use either groups or matrix, not both.",
|
||||
"required": [
|
||||
"vars",
|
||||
"sets"
|
||||
],
|
||||
"properties": {
|
||||
"vars": {
|
||||
"type": "object",
|
||||
"description": "Short names for models. Keys must be alphanumeric, 1-8 characters. All sets and evict_costs must use these IDs.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z0-9]{1,8}$"
|
||||
}
|
||||
},
|
||||
"evict_costs": {
|
||||
"type": "object",
|
||||
"description": "Relative cost of evicting a running model. Models not listed default to 1. Values must be positive integers.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"sets": {
|
||||
"type": "object",
|
||||
"description": "Named sets of concurrent model combinations. Values are DSL strings using & (AND), | (OR), () (grouping), and +ref (inline another set). Definition order is used for tie-breaking.",
|
||||
"minProperties": 1,
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"on_startup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"preload": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [],
|
||||
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Actions to perform on startup. Only supported action is preload."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
},
|
||||
"timeouts": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connect": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP connection timeout in seconds."
|
||||
},
|
||||
"keepalive": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 30,
|
||||
"description": "TCP keepalive connection timeout in seconds."
|
||||
},
|
||||
"responseHeader": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Time to wait for response headers in seconds."
|
||||
},
|
||||
"tlsHandshake": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 10,
|
||||
"description": "TLS handshake timeout in seconds."
|
||||
},
|
||||
"idleConn": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 90,
|
||||
"description": "Idle connection timeout in seconds."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Timeout settings for proxy connections to this peer."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
},
|
||||
"allOf": [
|
||||
{
|
||||
"if": {
|
||||
"required": ["groups"]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["matrix"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"if": {
|
||||
"required": ["matrix"]
|
||||
},
|
||||
"then": {
|
||||
"not": {
|
||||
"required": ["groups"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
+479
-42
@@ -1,53 +1,490 @@
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# globalTTL: the default TTL in seconds before unloading a model
|
||||
# - optional, default: 0 (never automatically unload)
|
||||
# - must be >= 0
|
||||
globalTTL: 0
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
models/llama-server-osx
|
||||
--port 9001
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
# keys are the model names used in API requests
|
||||
"gpt-oss-120b":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
- gpt-4o-mini
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/gpt-oss-120B.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# check this path for a HTTP 200 response for the server to be ready
|
||||
checkEndpoint: /health
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "gpt-oss 120B"
|
||||
|
||||
# unload model after 5 seconds
|
||||
ttl: 5
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A thinking model from OpenAI"
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
checkEndpoint: none
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# don't use these, just for testing if things are broken
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: -1 (use global default)
|
||||
# - ttl values must be a value greater than or equal to 0
|
||||
# - a ttl of -1 will use the global TTL value as the default
|
||||
# - a ttl of 0 will mean never unload
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "openai/gpt-oss-120B"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - always runs for the model
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||
# - optional, default: empty dictionary
|
||||
# - combine with aliases to create variant behaviour without reloading the model
|
||||
# - parameters are set in the request body JSON
|
||||
# - run after setParams so it will override any settings
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - model aliases will be automatically created for each key
|
||||
setParamsByID:
|
||||
"${MODEL_ID}":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: medium
|
||||
"${MODEL_ID}:high":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: low
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models running on slower hardware that need longer timeouts
|
||||
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
|
||||
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
|
||||
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
|
||||
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
|
||||
# - idleConn: idle connection timeout in seconds, default: 90 seconds
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 0
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# =============================================================================
|
||||
# matrix: run concurrent models with a solver-based swap DSL
|
||||
# =============================================================================
|
||||
#
|
||||
# Note:
|
||||
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||
#
|
||||
# The matrix declares valid combinations of models that can run concurrently.
|
||||
# When a model is requested, the solver finds the cheapest way to make it
|
||||
# available by evicting as few (and least costly) running models as possible.
|
||||
#
|
||||
# Solver behavior:
|
||||
# 1. Request arrives for model X
|
||||
# 2. If X is already running, forward immediately. Done.
|
||||
# 3. Find all sets containing X
|
||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||
# every running model NOT in that set
|
||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||
# 6. Evict what needs to stop. Start X. Forward request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||
# Only the requested model is started — others are not preloaded.
|
||||
#
|
||||
# A model not appearing in any set can only run alone.
|
||||
#
|
||||
matrix:
|
||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||
# - required for sets and evict_costs settings
|
||||
# - each entry is a short name to a real model ID. Do not use an alias
|
||||
# - used to keep set DSL logic short and easier to read
|
||||
# - sets and evict_costs only use identifiers defined in vars
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named sets of concurrent model combinations
|
||||
# Values are DSL strings with operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline another set's expression
|
||||
#
|
||||
# Expansion examples:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → expands llms inline, then applies & v
|
||||
sets:
|
||||
# LLM + TTS: switching between g/q/m won't evict v
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# LLM + TTS + reranker
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# LLM + image generation, no TTS
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# 70B model uses all GPUs, can only run alone
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
Executable
+164
@@ -0,0 +1,164 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cuda13" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
SD_TAG=master-${ARCH}
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||
USER_UID=0
|
||||
USER_GID=0
|
||||
USER_HOME=/root
|
||||
|
||||
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||
USER_UID=10001
|
||||
USER_GID=10001
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||
case "$ARCH" in
|
||||
"musa" | "vulkan")
|
||||
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||
--build-arg BASE=${CONTAINER_TAG} \
|
||||
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||
esac
|
||||
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
done
|
||||
Executable
+305
@@ -0,0 +1,305 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for llama-swap-docker with commit hash pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build CUDA image without cache
|
||||
# LLAMA_COMMIT_HASH=abc123 ./build-image.sh --cuda # Override llama.cpp commit
|
||||
# LLAMA_COMMIT_HASH=b8429 ./build-image.sh --vulkan # Override llama.cpp release tag (vulkan uses prebuilt binaries)
|
||||
# WHISPER_COMMIT_HASH=def456 ./build-image.sh --vulkan # Override whisper.cpp commit
|
||||
# SD_COMMIT_HASH=ghi789 ./build-image.sh --cuda # Override stable-diffusion.cpp commit
|
||||
#
|
||||
# Features:
|
||||
# - Auto-detects latest commit hashes from git repos
|
||||
# - Builds llama-swap from local source code
|
||||
# - Allows environment variable overrides for reproducible builds
|
||||
# - Cache-friendly: changing commit hash busts cache appropriately
|
||||
# - Supports both CUDA and Vulkan backends (requires explicit flag)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Parse command line arguments
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:cuda or llama-swap:vulkan)"
|
||||
echo " LLAMA_COMMIT_HASH Override llama.cpp commit hash"
|
||||
echo " WHISPER_COMMIT_HASH Override whisper.cpp commit hash"
|
||||
echo " SD_COMMIT_HASH Override stable-diffusion.cpp commit hash"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate backend selection
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Configuration
|
||||
if [[ -n "${DOCKER_IMAGE_TAG:-}" ]]; then
|
||||
# User provided a custom tag, use it as-is
|
||||
:
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
DOCKER_IMAGE_TAG="llama-swap:vulkan"
|
||||
else
|
||||
DOCKER_IMAGE_TAG="llama-swap:cuda"
|
||||
fi
|
||||
DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}"
|
||||
|
||||
# Single unified Dockerfile, backend selected via build arg
|
||||
DOCKERFILE="Dockerfile"
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Building for: Vulkan (AMD GPUs and compatible hardware)"
|
||||
else
|
||||
echo "Building for: CUDA (NVIDIA GPUs)"
|
||||
fi
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
|
||||
# Function to get the latest commit hash from a git repo's default branch
|
||||
get_latest_commit() {
|
||||
local repo_url="$1"
|
||||
local branch="${2:-master}"
|
||||
|
||||
# Try to get the latest commit hash for the specified branch
|
||||
git ls-remote --heads "${repo_url}" "${branch}" 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
# Function to get the default branch name (master or main)
|
||||
get_default_branch() {
|
||||
local repo_url="$1"
|
||||
|
||||
# Check for master first
|
||||
if git ls-remote --heads "${repo_url}" master &>/dev/null; then
|
||||
echo "master"
|
||||
elif git ls-remote --heads "${repo_url}" main &>/dev/null; then
|
||||
echo "main"
|
||||
else
|
||||
echo "master" # fallback
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to get the latest release tag from a GitHub repo
|
||||
get_latest_release_tag() {
|
||||
local owner_repo="$1"
|
||||
curl -fsSL "https://api.github.com/repos/${owner_repo}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap-docker Build Script"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Determine commit hashes / release tags - use env vars or auto-detect
|
||||
# For vulkan builds, llama and sd use GitHub release tags (prebuilt binaries).
|
||||
# For cuda builds (or whisper on any backend), use git commit hashes.
|
||||
if [[ -n "${LLAMA_COMMIT_HASH:-}" ]]; then
|
||||
LLAMA_HASH="${LLAMA_COMMIT_HASH}"
|
||||
echo "llama.cpp: Using provided version: ${LLAMA_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
LLAMA_HASH=$(get_latest_release_tag "ggml-org/llama.cpp")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest release tag: ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_BRANCH=$(get_default_branch "${LLAMA_REPO}")
|
||||
LLAMA_HASH=$(get_latest_commit "${LLAMA_REPO}" "${LLAMA_BRANCH}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: Auto-detected latest commit (${LLAMA_BRANCH}): ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${WHISPER_COMMIT_HASH:-}" ]]; then
|
||||
WHISPER_HASH="${WHISPER_COMMIT_HASH}"
|
||||
echo "whisper.cpp: Using provided commit hash: ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_BRANCH=$(get_default_branch "${WHISPER_REPO}")
|
||||
WHISPER_HASH=$(get_latest_commit "${WHISPER_REPO}" "${WHISPER_BRANCH}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: Auto-detected latest commit (${WHISPER_BRANCH}): ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
if [[ -n "${SD_COMMIT_HASH:-}" ]]; then
|
||||
SD_HASH="${SD_COMMIT_HASH}"
|
||||
echo "stable-diffusion.cpp: Using provided version: ${SD_HASH}"
|
||||
elif [[ "$BACKEND" == "vulkan" ]]; then
|
||||
SD_HASH=$(get_latest_release_tag "leejet/stable-diffusion.cpp")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest release tag for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest release tag: ${SD_HASH}"
|
||||
else
|
||||
SD_BRANCH=$(get_default_branch "${SD_REPO}")
|
||||
SD_HASH=$(get_latest_commit "${SD_REPO}" "${SD_BRANCH}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: Auto-detected latest commit (${SD_BRANCH}): ${SD_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Build the Docker image with commit hashes as build args
|
||||
# Build context is the repository root (..) so the Dockerfile can access Go source
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/${DOCKERFILE}"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
fi
|
||||
|
||||
# Use docker buildx with a custom builder for parallelism control
|
||||
# The legacy DOCKER_BUILDKIT=1 docker build doesn't respect BUILDKIT_MAX_PARALLELISM env var
|
||||
# We need to use a custom builder with a buildkitd.toml config file
|
||||
BUILDER_NAME="llama-swap-builder"
|
||||
|
||||
# Check if our custom builder exists with the right config, create/update if needed
|
||||
if ! docker buildx inspect "$BUILDER_NAME" >/dev/null 2>&1; then
|
||||
echo "Creating custom buildx builder with max-parallelism=1..."
|
||||
|
||||
# Create buildkitd.toml config file
|
||||
cat > buildkitd.toml << 'BUILDKIT_EOF'
|
||||
[worker.oci]
|
||||
max-parallelism = 1
|
||||
BUILDKIT_EOF
|
||||
|
||||
# Create the builder with the config
|
||||
docker buildx create --name "$BUILDER_NAME" \
|
||||
--driver docker-container \
|
||||
--buildkitd-config buildkitd.toml \
|
||||
--use
|
||||
else
|
||||
# Switch to our builder
|
||||
docker buildx use "$BUILDER_NAME"
|
||||
fi
|
||||
|
||||
echo "Building with sequential stages (one at a time), each using all CPU cores..."
|
||||
echo "Using builder: $BUILDER_NAME"
|
||||
|
||||
# Use docker buildx build with --load to load the image into Docker
|
||||
# The --builder flag ensures we use our custom builder with max-parallelism=1
|
||||
# Build context is the repository root so we can access Go source files
|
||||
docker buildx build --builder "$BUILDER_NAME" --load "${BUILD_ARGS[@]}" "${REPO_ROOT}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Verify all expected binaries exist in the image
|
||||
MISSING_BINARIES=()
|
||||
|
||||
for binary in llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap; do
|
||||
if ! docker run --rm "${DOCKER_IMAGE_TAG}" which "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing from the image:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "This usually indicates a build stage failure. Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --vulkan --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All expected binaries verified: llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tag: ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
echo " llama-swap: $(docker run --rm "${DOCKER_IMAGE_TAG}" cat /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need to mount render devices:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -0,0 +1,33 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
z-image:
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||
@@ -0,0 +1,11 @@
|
||||
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||
ARG SD_TAG=master-vulkan
|
||||
ARG BASE=llama-swap:latest
|
||||
|
||||
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||
FROM ${BASE}
|
||||
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
|
||||
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||
@@ -0,0 +1,44 @@
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
# Set default UID/GID arguments
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
ARG USER_HOME=/app
|
||||
|
||||
# Add user/group
|
||||
ENV HOME=$USER_HOME
|
||||
RUN if [ $UID -ne 0 ]; then \
|
||||
if [ $GID -ne 0 ]; then \
|
||||
groupadd --system --gid $GID app; \
|
||||
fi; \
|
||||
useradd --system --uid $UID --gid $GID \
|
||||
--home $USER_HOME app; \
|
||||
fi
|
||||
|
||||
# Handle paths
|
||||
RUN mkdir --parents $HOME /app
|
||||
RUN chown --recursive $UID:$GID $HOME /app
|
||||
|
||||
# Switch user
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Add /app to PATH
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,207 @@
|
||||
# Unified multi-stage Dockerfile for AI inference tools
|
||||
# Supports CUDA and Vulkan backends via BACKEND build arg
|
||||
#
|
||||
# Usage:
|
||||
# docker buildx build --build-arg BACKEND=cuda -t llama-swap:unified-cuda .
|
||||
# docker buildx build --build-arg BACKEND=vulkan -t llama-swap:unified-vulkan .
|
||||
# docker buildx build --build-arg BACKEND=cuda --build-arg CMAKE_CUDA_ARCHITECTURES="86;89" -t llama-swap:unified-cuda .
|
||||
#
|
||||
# Each project has its own install script that handles cloning, building,
|
||||
# and installing binaries. Build stages are independent for cache efficiency.
|
||||
|
||||
ARG BACKEND=cuda
|
||||
|
||||
# ── Builder bases ──────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 AS builder-base-cuda
|
||||
|
||||
ARG CMAKE_CUDA_ARCHITECTURES="60;61;75;86;89"
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS builder-base-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV CCACHE_DIR=/ccache
|
||||
ENV CCACHE_MAXSIZE=2G
|
||||
ENV PATH="/usr/lib/ccache:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git python3 python3-pip libssl-dev \
|
||||
curl ca-certificates ccache make wget software-properties-common \
|
||||
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
||||
spirv-headers \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# ── Select builder base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM builder-base-${BACKEND} AS builder-base
|
||||
|
||||
# ── Build whisper.cpp (fastest build, run first) ──────────────────────
|
||||
|
||||
FROM builder-base AS whisper-build
|
||||
ARG BACKEND=cuda
|
||||
ARG WHISPER_COMMIT_HASH=master
|
||||
COPY install-whisper.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=whisper-${BACKEND},target=/src/whisper.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-whisper.sh "${WHISPER_COMMIT_HASH}"
|
||||
|
||||
# ── Build stable-diffusion.cpp ────────────────────────────────────────
|
||||
|
||||
FROM builder-base AS sd-build
|
||||
ARG BACKEND=cuda
|
||||
ARG SD_COMMIT_HASH=master
|
||||
COPY install-sd.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=sd-${BACKEND},target=/src/stable-diffusion.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-sd.sh "${SD_COMMIT_HASH}"
|
||||
|
||||
# ── Build llama.cpp (slowest build, run last) ─────────────────────────
|
||||
|
||||
FROM builder-base AS llama-build
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=master
|
||||
COPY install-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-${BACKEND},target=/ccache \
|
||||
--mount=type=cache,id=llama-${BACKEND},target=/src/llama.cpp/build \
|
||||
BACKEND=${BACKEND} bash /build/install-llama.sh "${LLAMA_COMMIT_HASH}"
|
||||
|
||||
# ── Build ik_llama.cpp (CUDA only) ────────────────────────────────────
|
||||
#
|
||||
# Two named stages allow ARG BACKEND to select at build time:
|
||||
# - ik-llama-cuda : real build (from builder-base-cuda)
|
||||
# - ik-llama-vulkan: no-op (empty /install/bin, skips CUDA pull entirely)
|
||||
# BuildKit only evaluates the selected branch, so vulkan builds never
|
||||
# pull nvidia/cuda:*-devel or compile ik_llama.cpp.
|
||||
|
||||
FROM builder-base-vulkan AS ik-llama-vulkan
|
||||
RUN mkdir -p /install/bin
|
||||
|
||||
FROM builder-base-cuda AS ik-llama-cuda
|
||||
ARG IK_LLAMA_COMMIT_HASH=main
|
||||
COPY install-ik-llama.sh /build/
|
||||
RUN --mount=type=cache,id=ccache-cuda,target=/ccache \
|
||||
--mount=type=cache,id=ik-llama-cuda,target=/src/ik_llama.cpp/build \
|
||||
bash /build/install-ik-llama.sh "${IK_LLAMA_COMMIT_HASH}"
|
||||
|
||||
ARG BACKEND=cuda
|
||||
FROM ik-llama-${BACKEND} AS ik-llama-build
|
||||
|
||||
# ── Download llama-swap release binary ────────────────────────────────
|
||||
|
||||
FROM builder-base AS llama-swap-download
|
||||
ARG LS_VERSION=latest
|
||||
COPY install-llama-swap.sh /build/
|
||||
RUN bash /build/install-llama-swap.sh "${LS_VERSION}"
|
||||
|
||||
# ── Runtime bases ─────────────────────────────────────────────────────
|
||||
|
||||
FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 AS runtime-cuda
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# CUDA stub drivers for container compatibility
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so
|
||||
COPY --from=builder-base-cuda /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
|
||||
|
||||
# ──
|
||||
|
||||
FROM ubuntu:24.04 AS runtime-vulkan
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH="/usr/local/bin:${PATH}"
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 libvulkan1 mesa-vulkan-drivers \
|
||||
python3 curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ── Select runtime base by BACKEND ────────────────────────────────────
|
||||
|
||||
FROM runtime-${BACKEND} AS runtime
|
||||
|
||||
ARG BACKEND=cuda
|
||||
ARG LLAMA_COMMIT_HASH=unknown
|
||||
ARG WHISPER_COMMIT_HASH=unknown
|
||||
ARG SD_COMMIT_HASH=unknown
|
||||
ARG IK_LLAMA_COMMIT_HASH=unknown
|
||||
ARG RUN_UID=0
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3-numpy python3-sentencepiece python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user when RUN_UID != 0
|
||||
RUN if [ "$RUN_UID" != "0" ]; then \
|
||||
groupadd --system --gid $RUN_UID llama-swap && \
|
||||
useradd --system --uid $RUN_UID --gid $RUN_UID \
|
||||
--home /app --shell /sbin/nologin llama-swap; \
|
||||
fi && \
|
||||
mkdir -p /etc/llama-swap/config && \
|
||||
chown -R ${RUN_UID}:${RUN_UID} /etc/llama-swap
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy whisper.cpp binaries and libraries
|
||||
COPY --from=whisper-build /install/bin/whisper-server /usr/local/bin/
|
||||
COPY --from=whisper-build /install/bin/whisper-cli /usr/local/bin/
|
||||
COPY --from=whisper-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy stable-diffusion.cpp binaries and libraries
|
||||
COPY --from=sd-build /install/bin/sd-server /usr/local/bin/
|
||||
COPY --from=sd-build /install/bin/sd-cli /usr/local/bin/
|
||||
COPY --from=sd-build /install/lib/ /usr/local/lib/
|
||||
|
||||
# Copy llama.cpp binaries (statically linked)
|
||||
COPY --from=llama-build /install/bin/llama-server /usr/local/bin/
|
||||
COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
|
||||
|
||||
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
||||
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
||||
|
||||
# Install uv
|
||||
RUN pip install uv --break-system-packages
|
||||
|
||||
# Copy llama-swap binary
|
||||
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
||||
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
||||
|
||||
RUN ldconfig
|
||||
|
||||
COPY config.example.yaml /etc/llama-swap/config/config.yaml
|
||||
|
||||
# Version tracking
|
||||
RUN echo "llama.cpp: ${LLAMA_COMMIT_HASH}" > /versions.txt && \
|
||||
echo "whisper.cpp: ${WHISPER_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "stable-diffusion.cpp: ${SD_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_COMMIT_HASH}" >> /versions.txt && \
|
||||
echo "llama-swap: $(cat /tmp/llama-swap-version)" >> /versions.txt && \
|
||||
echo "backend: ${BACKEND}" >> /versions.txt && \
|
||||
echo "build_timestamp: $(date -u +%Y-%m-%dT%H:%M:%SZ)" >> /versions.txt
|
||||
|
||||
RUN mkdir -p /models && chown ${RUN_UID}:${RUN_UID} /models
|
||||
WORKDIR /models
|
||||
USER ${RUN_UID}
|
||||
ENTRYPOINT ["llama-swap"]
|
||||
CMD ["-config", "/etc/llama-swap/config/config.yaml", "-listen", "0.0.0.0:8080"]
|
||||
@@ -0,0 +1,8 @@
|
||||
# Unified Docker Container
|
||||
|
||||
These scripts create a custom llama-swap container that contains:
|
||||
|
||||
- llama-server for LLMs, rerank and embedding model support
|
||||
- sd-server (stable-diffusion.cpp) for image generation
|
||||
- whisper.cpp for ASR
|
||||
|
||||
Executable
+303
@@ -0,0 +1,303 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build script for unified container with version pinning
|
||||
#
|
||||
# Usage:
|
||||
# ./build-image.sh --cuda # Build CUDA image
|
||||
# ./build-image.sh --vulkan # Build Vulkan image
|
||||
# ./build-image.sh --cuda --no-cache # Build without cache
|
||||
# LLAMA_REF=b1234 ./build-image.sh --vulkan # Pin llama.cpp to a commit hash
|
||||
# LLAMA_REF=v1.2.3 ./build-image.sh --cuda # Pin llama.cpp to a tag
|
||||
# WHISPER_REF=v1.0.0 ./build-image.sh --vulkan # Pin whisper.cpp to a tag
|
||||
# SD_REF=master ./build-image.sh --cuda # Pin stable-diffusion.cpp to a branch
|
||||
# LS_VERSION=170 ./build-image.sh --cuda # Override llama-swap version
|
||||
# IK_LLAMA_REF=main ./build-image.sh --cuda # Pin ik_llama.cpp to main branch (CUDA only)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BACKEND=""
|
||||
NO_CACHE=false
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--cuda)
|
||||
BACKEND="cuda"
|
||||
;;
|
||||
--vulkan)
|
||||
BACKEND="vulkan"
|
||||
;;
|
||||
--no-cache)
|
||||
NO_CACHE=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --cuda Build CUDA image (NVIDIA GPUs)"
|
||||
echo " --vulkan Build Vulkan image (AMD GPUs and compatible hardware)"
|
||||
echo " --no-cache Force rebuild without using Docker cache"
|
||||
echo " --help, -h Show this help message"
|
||||
echo ""
|
||||
echo "Environment variables:"
|
||||
echo " DOCKER_IMAGE_TAG Set custom image tag (default: llama-swap:unified-cuda or llama-swap:unified-vulkan)"
|
||||
echo " LLAMA_REF Pin llama.cpp to a commit, tag, or branch"
|
||||
echo " WHISPER_REF Pin whisper.cpp to a commit, tag, or branch"
|
||||
echo " SD_REF Pin stable-diffusion.cpp to a commit, tag, or branch"
|
||||
echo " IK_LLAMA_REF Pin ik_llama.cpp to a commit, tag, or branch (CUDA only)"
|
||||
echo " LS_VERSION Override llama-swap version (e.g., '170' or 'latest')"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$BACKEND" ]]; then
|
||||
echo "Error: No backend specified. Please use --cuda or --vulkan."
|
||||
echo ""
|
||||
echo "Usage: ./build-image.sh --cuda|--vulkan [--no-cache]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOCKER_IMAGE_TAG="${DOCKER_IMAGE_TAG:-llama-swap:unified-${BACKEND}}"
|
||||
|
||||
# Git repository URLs
|
||||
LLAMA_REPO="https://github.com/ggml-org/llama.cpp.git"
|
||||
WHISPER_REPO="https://github.com/ggml-org/whisper.cpp.git"
|
||||
SD_REPO="https://github.com/leejet/stable-diffusion.cpp.git"
|
||||
LLAMA_SWAP_REPO="https://github.com/mostlygeek/llama-swap.git"
|
||||
IK_LLAMA_REPO="https://github.com/ikawrakow/ik_llama.cpp.git"
|
||||
|
||||
# Resolve a git ref (commit hash, tag, or branch) to a full commit hash.
|
||||
# Requires only: git, network access to the remote.
|
||||
resolve_ref() {
|
||||
local repo_url="$1"
|
||||
local ref="$2"
|
||||
|
||||
# Full 40-char SHA — use as-is
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "${ref}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Try tag then branch (exact match)
|
||||
local hash
|
||||
hash=$(git ls-remote "${repo_url}" "refs/tags/${ref}" "refs/heads/${ref}" 2>/dev/null | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
|
||||
# Short hash (7+ chars): scan all refs for a SHA with this prefix
|
||||
if [[ "${ref}" =~ ^[0-9a-f]{7,}$ ]]; then
|
||||
hash=$(git ls-remote "${repo_url}" 2>/dev/null | grep "^${ref}" | head -1 | cut -f1)
|
||||
if [[ -n "${hash}" ]]; then
|
||||
echo "${hash}"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "ERROR: Could not resolve ref '${ref}' for ${repo_url}" >&2
|
||||
if [[ "${ref}" =~ ^[0-9a-f]+$ && ${#ref} -lt 7 ]]; then
|
||||
echo " Short hashes must be at least 7 characters (got ${#ref})." >&2
|
||||
else
|
||||
echo " Tried: tag, branch, git ls-remote prefix match" >&2
|
||||
fi
|
||||
echo " Use a full 40-char SHA, a tag name, a branch name, or a 7+ char short hash." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Resolve HEAD of a repo without needing to know the default branch name.
|
||||
get_latest_hash() {
|
||||
git ls-remote "${1}" HEAD 2>/dev/null | head -1 | cut -f1
|
||||
}
|
||||
|
||||
echo "=========================================="
|
||||
echo "llama-swap Unified Build (${BACKEND})"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Resolve llama.cpp ref
|
||||
if [[ -n "${LLAMA_REF:-}" ]]; then
|
||||
LLAMA_HASH=$(resolve_ref "${LLAMA_REPO}" "${LLAMA_REF}") || exit 1
|
||||
echo "llama.cpp: ${LLAMA_REF} -> ${LLAMA_HASH}"
|
||||
else
|
||||
LLAMA_HASH=$(get_latest_hash "${LLAMA_REPO}")
|
||||
if [[ -z "${LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama.cpp: latest HEAD: ${LLAMA_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve whisper.cpp ref
|
||||
if [[ -n "${WHISPER_REF:-}" ]]; then
|
||||
WHISPER_HASH=$(resolve_ref "${WHISPER_REPO}" "${WHISPER_REF}") || exit 1
|
||||
echo "whisper.cpp: ${WHISPER_REF} -> ${WHISPER_HASH}"
|
||||
else
|
||||
WHISPER_HASH=$(get_latest_hash "${WHISPER_REPO}")
|
||||
if [[ -z "${WHISPER_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for whisper.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "whisper.cpp: latest HEAD: ${WHISPER_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve stable-diffusion.cpp ref
|
||||
if [[ -n "${SD_REF:-}" ]]; then
|
||||
SD_HASH=$(resolve_ref "${SD_REPO}" "${SD_REF}") || exit 1
|
||||
echo "stable-diffusion.cpp: ${SD_REF} -> ${SD_HASH}"
|
||||
else
|
||||
SD_HASH=$(get_latest_hash "${SD_REPO}")
|
||||
if [[ -z "${SD_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for stable-diffusion.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "stable-diffusion.cpp: latest HEAD: ${SD_HASH}"
|
||||
fi
|
||||
|
||||
# Resolve ik_llama.cpp ref (CUDA only)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
if [[ -n "${IK_LLAMA_REF:-}" ]]; then
|
||||
IK_LLAMA_HASH=$(resolve_ref "${IK_LLAMA_REPO}" "${IK_LLAMA_REF}") || exit 1
|
||||
echo "ik_llama.cpp: ${IK_LLAMA_REF} -> ${IK_LLAMA_HASH}"
|
||||
else
|
||||
IK_LLAMA_HASH=$(get_latest_hash "${IK_LLAMA_REPO}")
|
||||
if [[ -z "${IK_LLAMA_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for ik_llama.cpp" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "ik_llama.cpp: latest HEAD: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
else
|
||||
IK_LLAMA_HASH="n/a"
|
||||
echo "ik_llama.cpp: skipped (vulkan build)"
|
||||
fi
|
||||
|
||||
# Resolve llama-swap ref
|
||||
if [[ -n "${LS_VERSION:-}" ]]; then
|
||||
LS_HASH=$(resolve_ref "${LLAMA_SWAP_REPO}" "${LS_VERSION}") || exit 1
|
||||
echo "llama-swap: ${LS_VERSION} -> ${LS_HASH}"
|
||||
else
|
||||
LS_HASH=$(get_latest_hash "${LLAMA_SWAP_REPO}")
|
||||
if [[ -z "${LS_HASH}" ]]; then
|
||||
echo "ERROR: Could not determine latest commit for llama-swap" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "llama-swap: latest HEAD: ${LS_HASH}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Docker build..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
BUILD_ARGS=(
|
||||
--build-arg "BACKEND=${BACKEND}"
|
||||
--build-arg "LLAMA_COMMIT_HASH=${LLAMA_HASH}"
|
||||
--build-arg "WHISPER_COMMIT_HASH=${WHISPER_HASH}"
|
||||
--build-arg "SD_COMMIT_HASH=${SD_HASH}"
|
||||
--build-arg "IK_LLAMA_COMMIT_HASH=${IK_LLAMA_HASH}"
|
||||
--build-arg "LS_VERSION=${LS_HASH}"
|
||||
-t "${DOCKER_IMAGE_TAG}"
|
||||
-f "${SCRIPT_DIR}/Dockerfile"
|
||||
)
|
||||
|
||||
if [[ "$NO_CACHE" == true ]]; then
|
||||
BUILD_ARGS+=(--no-cache)
|
||||
echo "Note: Building without cache"
|
||||
elif [[ "${GITHUB_ACTIONS:-}" == "true" && "${ACT:-}" != "true" ]]; then
|
||||
CACHE_REF="ghcr.io/mostlygeek/llama-swap:unified-${BACKEND}-cache"
|
||||
BUILD_ARGS+=(
|
||||
--cache-from "type=registry,ref=${CACHE_REF}"
|
||||
--cache-to "type=registry,ref=${CACHE_REF},mode=max"
|
||||
)
|
||||
echo "Note: Using registry cache (${CACHE_REF})"
|
||||
fi
|
||||
|
||||
DOCKER_BUILDKIT=1 docker buildx build --load "${BUILD_ARGS[@]}" "${SCRIPT_DIR}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verifying build artifacts..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
EXPECTED_BINARIES=(llama-server llama-cli whisper-server whisper-cli sd-server sd-cli llama-swap)
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
EXPECTED_BINARIES+=(ik-llama-server)
|
||||
fi
|
||||
|
||||
MISSING_BINARIES=()
|
||||
for binary in "${EXPECTED_BINARIES[@]}"; do
|
||||
if ! docker run --rm --entrypoint which "${DOCKER_IMAGE_TAG}" "${binary}" >/dev/null 2>&1; then
|
||||
MISSING_BINARIES+=("${binary}")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${#MISSING_BINARIES[@]} -gt 0 ]]; then
|
||||
echo "ERROR: Build succeeded but the following binaries are missing:"
|
||||
for binary in "${MISSING_BINARIES[@]}"; do
|
||||
echo " - ${binary}"
|
||||
done
|
||||
echo ""
|
||||
echo "Try running with --no-cache flag:"
|
||||
echo " ./build-image.sh --${BACKEND} --no-cache"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
VERIFIED_LIST="llama-server, llama-cli, whisper-server, whisper-cli, sd-server, sd-cli, llama-swap"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
VERIFIED_LIST="${VERIFIED_LIST}, ik-llama-server"
|
||||
fi
|
||||
echo "All expected binaries verified: ${VERIFIED_LIST}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Building rootless image..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
ROOTLESS_TAG="${DOCKER_IMAGE_TAG}-rootless"
|
||||
docker buildx build --load -t "${ROOTLESS_TAG}" - <<EOF
|
||||
FROM ${DOCKER_IMAGE_TAG}
|
||||
USER root
|
||||
RUN groupadd --system --gid 10001 llama-swap && \\
|
||||
useradd --system --uid 10001 --gid 10001 \\
|
||||
--home /app --shell /sbin/nologin llama-swap && \\
|
||||
chown -R 10001:10001 /etc/llama-swap /models
|
||||
USER 10001
|
||||
EOF
|
||||
|
||||
echo "Rootless image built: ${ROOTLESS_TAG}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Build complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Image tags:"
|
||||
echo " ${DOCKER_IMAGE_TAG}"
|
||||
echo " ${ROOTLESS_TAG}"
|
||||
echo ""
|
||||
echo "Built with:"
|
||||
echo " llama.cpp: ${LLAMA_HASH}"
|
||||
echo " whisper.cpp: ${WHISPER_HASH}"
|
||||
echo " stable-diffusion.cpp: ${SD_HASH}"
|
||||
if [[ "$BACKEND" == "cuda" ]]; then
|
||||
echo " ik_llama.cpp: ${IK_LLAMA_HASH}"
|
||||
fi
|
||||
echo " llama-swap: $(docker run --rm --entrypoint cat "${DOCKER_IMAGE_TAG}" /versions.txt | grep llama-swap | cut -d' ' -f2-)"
|
||||
echo ""
|
||||
if [[ "$BACKEND" == "vulkan" ]]; then
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri ${DOCKER_IMAGE_TAG}"
|
||||
echo ""
|
||||
echo "Note: For AMD GPUs, you may also need:"
|
||||
echo " docker run -it --rm --device /dev/dri:/dev/dri --group-add video ${DOCKER_IMAGE_TAG}"
|
||||
else
|
||||
echo "Run with:"
|
||||
echo " docker run -it --rm --gpus all ${DOCKER_IMAGE_TAG}"
|
||||
fi
|
||||
@@ -0,0 +1,33 @@
|
||||
# placeholder example configuration
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port ${PORT}
|
||||
|
||||
"whisper":
|
||||
checkEndpoint: /v1/audio/transcriptions/
|
||||
cmd: >
|
||||
whisper-server
|
||||
--port ${PORT}
|
||||
--m /models/whisper.bin
|
||||
--flash-attn
|
||||
--request-path /v1/audio/transcriptions --inference-path ""
|
||||
|
||||
"image":
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Install ik_llama.cpp - clone, build, and install binaries
|
||||
# Usage: ./install-ik-llama.sh <commit_hash>
|
||||
# Note: CUDA only; always built against builder-base-cuda
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-main}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted build dir doesn't break clone)
|
||||
echo "=== Cloning ik_llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/ik_llama.cpp
|
||||
cd /src/ik_llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ikawrakow/ik_llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DGGML_CUDA=ON
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda -Wl,--allow-shlib-undefined"
|
||||
)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building ik_llama.cpp ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target llama-server
|
||||
|
||||
if [ ! -f "build/bin/llama-server" ]; then
|
||||
echo "FATAL: llama-server not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install as ik-llama-server to avoid collision with llama.cpp's llama-server
|
||||
cp "build/bin/llama-server" "/install/bin/ik-llama-server"
|
||||
echo "=== ik_llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
Executable
+67
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
# Install llama-swap - download latest release binary from GitHub
|
||||
# Usage: ./install-llama-swap.sh [version]
|
||||
# version: release version number (e.g., "170") or "latest" (default)
|
||||
set -e
|
||||
|
||||
VERSION="${1:-latest}"
|
||||
REPO="mostlygeek/llama-swap"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# If a full commit hash is given, find the release tag that points to it
|
||||
if echo "${VERSION}" | grep -qE '^[0-9a-f]{40}$'; then
|
||||
echo "=== Resolving commit ${VERSION:0:7} to release tag ==="
|
||||
TAG=$(git ls-remote --tags "https://github.com/${REPO}.git" 2>/dev/null \
|
||||
| grep "^${VERSION}" | sed 's|.*refs/tags/||' | grep -v '\^{}' | head -1)
|
||||
if [ -n "${TAG}" ]; then
|
||||
echo "Resolved to tag: ${TAG}"
|
||||
VERSION="${TAG#v}"
|
||||
else
|
||||
echo "No release tag found for commit ${VERSION:0:7}, using latest"
|
||||
VERSION="latest"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Strip leading 'v' prefix so both "198" and "v198" work
|
||||
VERSION="${VERSION#v}"
|
||||
|
||||
# Resolve "latest" to actual version number
|
||||
if [ "$VERSION" = "latest" ]; then
|
||||
echo "=== Resolving latest llama-swap release ==="
|
||||
VERSION=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" \
|
||||
| grep '"tag_name"' | head -1 | cut -d'"' -f4 | sed 's/^v//')
|
||||
if [ -z "$VERSION" ]; then
|
||||
echo "FATAL: Could not determine latest release version" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Latest version: ${VERSION}"
|
||||
fi
|
||||
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64|arm64) ARCH="arm64" ;;
|
||||
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Download and extract
|
||||
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
|
||||
echo "=== Downloading llama-swap v${VERSION} ==="
|
||||
echo "URL: $URL"
|
||||
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
||||
tar -xzf /tmp/llama-swap.tar.gz -C /install/bin/
|
||||
rm /tmp/llama-swap.tar.gz
|
||||
|
||||
# Validate
|
||||
if [ ! -x "/install/bin/llama-swap" ]; then
|
||||
echo "FATAL: llama-swap binary not found or not executable" >&2
|
||||
ls -la /install/bin/ >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "$VERSION" > /install/llama-swap-version
|
||||
|
||||
echo "=== llama-swap v${VERSION} installed ==="
|
||||
ls -la /install/bin/llama-swap
|
||||
Executable
+63
@@ -0,0 +1,63 @@
|
||||
#!/bin/bash
|
||||
# Install llama.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-llama.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/llama.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning llama.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/llama.cpp
|
||||
cd /src/llama.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/llama.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DLLAMA_BUILD_TESTS=OFF
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(llama-cli llama-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building llama.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
echo "=== llama.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
Executable
+68
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
# Install stable-diffusion.cpp - clone, build, and install binaries and library
|
||||
# Usage: BACKEND=cuda|vulkan ./install-sd.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/stable-diffusion.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning stable-diffusion.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/stable-diffusion.cpp
|
||||
cd /src/stable-diffusion.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/leejet/stable-diffusion.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
git submodule update --init --recursive --depth=1
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
-DSD_BUILD_EXAMPLES=ON
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
-DSD_CUDA=ON
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
-DSD_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(stable-diffusion sd-cli sd-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building stable-diffusion.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in sd-cli sd-server; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== stable-diffusion.cpp build complete ==="
|
||||
ls -la /install/bin/ /install/lib/
|
||||
Executable
+64
@@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
# Install whisper.cpp - clone, build, and install binaries
|
||||
# Usage: BACKEND=cuda|vulkan ./install-whisper.sh <commit_hash>
|
||||
set -e
|
||||
|
||||
COMMIT_HASH="${1:-master}"
|
||||
BACKEND="${BACKEND:-cuda}"
|
||||
|
||||
mkdir -p /install/bin /install/lib
|
||||
|
||||
# Clone and checkout (init-based so cache-mounted /src/whisper.cpp/build dir doesn't break clone)
|
||||
echo "=== Cloning whisper.cpp at ${COMMIT_HASH} ==="
|
||||
mkdir -p /src/whisper.cpp
|
||||
cd /src/whisper.cpp
|
||||
if [ ! -d .git ]; then
|
||||
git init
|
||||
git remote add origin https://github.com/ggml-org/whisper.cpp.git
|
||||
fi
|
||||
git fetch --depth=1 origin "${COMMIT_HASH}"
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# Common cmake flags
|
||||
CMAKE_FLAGS=(
|
||||
-DGGML_NATIVE=OFF
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
)
|
||||
|
||||
if [ "$BACKEND" = "cuda" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=ON
|
||||
-DGGML_VULKAN=OFF
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES:?CMAKE_CUDA_ARCHITECTURES must be set}"
|
||||
"-DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler"
|
||||
"-DCMAKE_EXE_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-Wl,-rpath-link,/usr/local/cuda/lib64/stubs -lcuda"
|
||||
)
|
||||
elif [ "$BACKEND" = "vulkan" ]; then
|
||||
CMAKE_FLAGS+=(
|
||||
-DGGML_CUDA=OFF
|
||||
-DGGML_VULKAN=ON
|
||||
)
|
||||
fi
|
||||
|
||||
TARGETS=(whisper-cli whisper-server)
|
||||
|
||||
rm -rf build/CMakeCache.txt build/CMakeFiles 2>/dev/null || true
|
||||
|
||||
echo "=== Building whisper.cpp for ${BACKEND} ==="
|
||||
cmake -B build "${CMAKE_FLAGS[@]}"
|
||||
cmake --build build --config Release -j"$(nproc)" --target "${TARGETS[@]}"
|
||||
|
||||
for bin in "${TARGETS[@]}"; do
|
||||
if [ ! -f "build/bin/$bin" ]; then
|
||||
echo "FATAL: $bin not found in build/bin/" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "build/bin/$bin" "/install/bin/"
|
||||
done
|
||||
find build -name "*.so*" -type f -exec cp {} /install/lib/ \;
|
||||
|
||||
echo "=== whisper.cpp build complete ==="
|
||||
ls -la /install/bin/
|
||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 351 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 198 KiB |
@@ -0,0 +1,582 @@
|
||||
# config.yaml
|
||||
|
||||
llama-swap is designed to be very simple: one binary, one configuration file.
|
||||
|
||||
## minimal viable config
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
|
||||
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
|
||||
model2:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
|
||||
model3:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
||||
```
|
||||
|
||||
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model which is useful if you want to run multiple models at the same time with the `matrix` feature.
|
||||
|
||||
## Advanced control with `cmd`
|
||||
|
||||
llama-swap is also about customizability. You can use any CLI flag available:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: | # support for multi-line
|
||||
llama-server --PORT ${PORT} -m /path/to/model.gguf
|
||||
--ctx-size 8192
|
||||
--jinja
|
||||
--cache-type-k q8_0
|
||||
--cache-type-v q8_0
|
||||
```
|
||||
|
||||
## Support for any OpenAI API compatible server
|
||||
|
||||
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"Q3-30B-CODER-VLLM":
|
||||
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
|
||||
# cmdStop provides a reliable way to stop containers
|
||||
cmdStop: docker stop vllm-coder
|
||||
cmd: |
|
||||
docker run --init --rm --name vllm-coder
|
||||
--runtime=nvidia --gpus '"device=2,3"'
|
||||
--shm-size=16g
|
||||
-v /mnt/nvme/vllm-cache:/root/.cache
|
||||
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
|
||||
vllm/vllm-openai:v0.10.0
|
||||
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
|
||||
--served-model-name "Q3-30B-CODER-VLLM"
|
||||
--enable-expert-parallel
|
||||
--swap-space 16
|
||||
--max-num-seqs 512
|
||||
--max-model-len 65536
|
||||
--max-seq-len-to-capture 65536
|
||||
--gpu-memory-utilization 0.9
|
||||
--tensor-parallel-size 2
|
||||
--trust-remote-code
|
||||
```
|
||||
|
||||
## Many more features..
|
||||
|
||||
llama-swap supports many more features to customize how you want to manage your environment.
|
||||
|
||||
| Feature | Description |
|
||||
| --------- | ---------------------------------------------- |
|
||||
| `ttl` | automatic unloading of models after a timeout |
|
||||
| `macros` | reusable snippets to use in configurations |
|
||||
| `matrix` | run multiple models at a time |
|
||||
| `hooks` | event driven functionality |
|
||||
| `env` | define environment variables per model |
|
||||
| `aliases` | serve a model with different names |
|
||||
| `filters` | modify requests before sending to the upstream |
|
||||
| `...` | And many more tweaks |
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# globalTTL: the default TTL in seconds before unloading a model
|
||||
# - optional, default: 0 (never automatically unload)
|
||||
# - must be >= 0
|
||||
globalTTL: 0
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
# keys are the model names used in API requests
|
||||
"gpt-oss-120b":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/gpt-oss-120B.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "gpt-oss 120B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A thinking model from OpenAI"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: -1 (use global default)
|
||||
# - ttl values must be a value greater than or equal to 0
|
||||
# - a ttl of -1 will use the global TTL value as the default
|
||||
# - a ttl of 0 will mean never unload
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "openai/gpt-oss-120B"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - always runs for the model
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||
# - optional, default: empty dictionary
|
||||
# - combine with aliases to create variant behaviour without reloading the model
|
||||
# - parameters are set in the request body JSON
|
||||
# - run after setParams so it will override any settings
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - model aliases will be automatically created for each key
|
||||
setParamsByID:
|
||||
"${MODEL_ID}":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: medium
|
||||
"${MODEL_ID}:high":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: low
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# timeouts: configure proxy connection timeouts for this model
|
||||
# - optional, defaults shown below
|
||||
# - useful for models running on slower hardware that need longer timeouts
|
||||
# - connect: TCP dial connection timeout in seconds, default: 30 seconds
|
||||
# - keepalive: TCP connection keepalive timeout, default: 30 seconds
|
||||
# - responseHeader: time to wait for response headers in seconds, default: 0 (no timeout)
|
||||
# - tlsHandshake: TLS handshake timeout in seconds, default: 10 seconds
|
||||
# - idleConn: idle connection timeout in seconds, default: 90 seconds
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 0
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# =============================================================================
|
||||
# matrix: run concurrent models with a solver-based swap DSL
|
||||
# =============================================================================
|
||||
#
|
||||
# Note:
|
||||
# A config must use either a matrix or legacy groups, not both. A configuration error
|
||||
# will occur if both are defined. Configuration examples for legacy Groups can be found:
|
||||
# https://github.com/mostlygeek/llama-swap/blob/40e39f7/config.example.yaml#L334-L396
|
||||
#
|
||||
# The matrix declares valid combinations of models that can run concurrently.
|
||||
# When a model is requested, the solver finds the cheapest way to make it
|
||||
# available by evicting as few (and least costly) running models as possible.
|
||||
#
|
||||
# Solver behavior:
|
||||
# 1. Request arrives for model X
|
||||
# 2. If X is already running, forward immediately. Done.
|
||||
# 3. Find all sets containing X
|
||||
# 4. For each candidate set, compute cost: sum of evict_costs for
|
||||
# every running model NOT in that set
|
||||
# 5. Pick lowest cost candidate. Ties broken by definition order.
|
||||
# 6. Evict what needs to stop. Start X. Forward request.
|
||||
#
|
||||
# Subset semantics: a set [a, b, c] means any subset is valid.
|
||||
# Only the requested model is started — others are not preloaded.
|
||||
#
|
||||
# A model not appearing in any set can only run alone.
|
||||
#
|
||||
matrix:
|
||||
# vars: short names for models (alphanumeric, 1-8 chars)
|
||||
# - required for sets and evict_costs settings
|
||||
# - each entry is a short name to a real model ID. Do not use an alias
|
||||
# - used to keep set DSL logic short and easier to read
|
||||
# - sets and evict_costs only use identifiers defined in vars
|
||||
vars:
|
||||
g: gemma-model
|
||||
q: qwen-model
|
||||
m: mistral-model
|
||||
v: voxtral-model
|
||||
e: reranker-model
|
||||
L: llama-70B
|
||||
sd: stable-diffusion
|
||||
|
||||
# evict_costs: relative cost of losing a running model (default: 1)
|
||||
evict_costs:
|
||||
v: 50 # vllm backend, slow cold start
|
||||
L: 30 # 70B weights, slow to load
|
||||
|
||||
# sets: named sets of concurrent model combinations
|
||||
# Values are DSL strings with operators:
|
||||
# & AND (models run together)
|
||||
# | OR (alternatives)
|
||||
# () grouping
|
||||
# +ref inline another set's expression
|
||||
#
|
||||
# Expansion examples:
|
||||
# "L" → [L]
|
||||
# "a & b" → [a, b]
|
||||
# "a | b" → [a], [b]
|
||||
# "(a | b) & c" → [a, c], [b, c]
|
||||
# "(a | b) & (c | d)" → [a,c], [a,d], [b,c], [b,d]
|
||||
# "+llms & v" → expands llms inline, then applies & v
|
||||
sets:
|
||||
# LLM + TTS: switching between g/q/m won't evict v
|
||||
# expands to: [g,v], [q,v], [m,v]
|
||||
standard: "(g | q | m) & v"
|
||||
|
||||
# LLM + TTS + reranker
|
||||
# expands to: [g,v,e], [q,v,e]
|
||||
with_rerank: "(g | q) & v & e"
|
||||
|
||||
# LLM + image generation, no TTS
|
||||
# expands to: [g,sd], [q,sd]
|
||||
creative: "(g | q) & sd"
|
||||
|
||||
# 70B model uses all GPUs, can only run alone
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# timeouts: configure proxy connection timeouts for this peer
|
||||
# - optional, defaults shown below
|
||||
# - useful when the peer runs on slower hardware
|
||||
# - set any value to 0 to disable that timeout (not recommended)
|
||||
timeouts:
|
||||
connect: 30
|
||||
keepalive: 30
|
||||
responseHeader: 60
|
||||
tlsHandshake: 10
|
||||
idleConn: 90
|
||||
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
```
|
||||
@@ -0,0 +1,9 @@
|
||||
## Container Security
|
||||
|
||||
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||
|
||||
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||
|
||||
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||
|
||||
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||
@@ -0,0 +1,153 @@
|
||||
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||
|
||||
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||
|
||||
## Here's what you you need:
|
||||
|
||||
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||
- 24GB VRAM video card
|
||||
|
||||
## Running aider
|
||||
|
||||
The goal is getting this command line to work:
|
||||
|
||||
```sh
|
||||
aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/QwQ \
|
||||
--editor-model openai/qwen-coder-32B \
|
||||
--model-settings-file aider.model.settings.yml \
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||
```
|
||||
|
||||
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||
|
||||
## Create an aider model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
|
||||
#
|
||||
# !!! important: model names must match llama-swap configuration names !!!
|
||||
#
|
||||
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
```
|
||||
|
||||
## llama-swap configuration
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
-ngl 99
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
|
||||
"QwQ":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||
-ngl 99
|
||||
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
```
|
||||
|
||||
## Advanced, Dual GPU Configuration
|
||||
|
||||
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||
|
||||
In llama-swap's configuration file:
|
||||
|
||||
1. add a `profiles` section with `aider` as the profile name
|
||||
2. using the `env` field to specify the GPU IDs for each model
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# Add a profile for aider
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: /path/to/llama-server ...
|
||||
|
||||
"QwQ":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: /path/to/llama-server ...
|
||||
```
|
||||
|
||||
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
- name: "openai/aider:QwQ"
|
||||
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
```
|
||||
|
||||
Run aider with:
|
||||
|
||||
```sh
|
||||
$ aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/aider:QwQ \
|
||||
--editor-model openai/aider:qwen-coder-32B \
|
||||
--config aider.conf.yml \
|
||||
--model-settings-file aider.model.settings.yml
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||
```
|
||||
@@ -0,0 +1,28 @@
|
||||
# this makes use of llama-swap's profile feature to
|
||||
# keep the architect and editor models in VRAM on different GPUs
|
||||
|
||||
- name: "openai/aider:QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
@@ -0,0 +1,26 @@
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
healthCheckTimeout: 300
|
||||
logLevel: debug
|
||||
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
aliases:
|
||||
- coder
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--ctx-size-draft 16000
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||
-ngl 99 -ngld 99
|
||||
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
"QwQ":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6
|
||||
--repeat-penalty 1.1
|
||||
--dry-multiplier 0.5
|
||||
--min-p 0.01
|
||||
--top-k 40
|
||||
--top-p 0.95
|
||||
-ngl 99 -ngld 99
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -0,0 +1,3 @@
|
||||
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||
|
||||
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Default initializes a default in-process dispatcher
|
||||
var Default = NewDispatcherConfig(25000)
|
||||
|
||||
// On subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work. This
|
||||
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||
func On[T Event](handler func(T)) context.CancelFunc {
|
||||
return Subscribe(Default, handler)
|
||||
}
|
||||
|
||||
// OnType subscribes to an event with the specified event type. This functions
|
||||
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||
return SubscribeTo(Default, eventType, handler)
|
||||
}
|
||||
|
||||
// Emit writes an event into the dispatcher. This functions same way as
|
||||
// Publish() but uses the default dispatcher instead.
|
||||
func Emit[T Event](ev T) {
|
||||
Publish(Default, ev)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||
*/
|
||||
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||
unsub()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPublish(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe
|
||||
var count int64
|
||||
defer On(func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(4)
|
||||
Emit(MyEvent1{})
|
||||
Emit(MyEvent1{})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(4), count)
|
||||
}
|
||||
+324
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Event represents an event contract
|
||||
type Event interface {
|
||||
Type() uint32
|
||||
}
|
||||
|
||||
// registry holds an immutable sorted array of event mappings
|
||||
type registry struct {
|
||||
keys []uint32 // Event types (sorted)
|
||||
grps []any // Corresponding subscribers
|
||||
}
|
||||
|
||||
// ------------------------------------- Dispatcher -------------------------------------
|
||||
|
||||
// Dispatcher represents an event dispatcher.
|
||||
type Dispatcher struct {
|
||||
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||
done chan struct{} // Cancellation
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new dispatcher of events.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return NewDispatcherConfig(50000)
|
||||
}
|
||||
|
||||
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}),
|
||||
maxQueue: maxQueue,
|
||||
}
|
||||
|
||||
d.subs.Store(®istry{
|
||||
keys: make([]uint32, 0, 16),
|
||||
grps: make([]any, 0, 16),
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
// Close closes the dispatcher
|
||||
func (d *Dispatcher) Close() error {
|
||||
close(d.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed returns whether the dispatcher is closed or not
|
||||
func (d *Dispatcher) isClosed() bool {
|
||||
select {
|
||||
case <-d.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// findGroup performs a lock-free binary search for the event type
|
||||
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||
reg := d.subs.Load()
|
||||
keys := reg.keys
|
||||
|
||||
// Inlined binary search for better cache locality
|
||||
left, right := 0, len(keys)
|
||||
for left < right {
|
||||
mid := left + (right-left)/2
|
||||
if keys[mid] < eventType {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
if left < len(keys) && keys[left] == eventType {
|
||||
return reg.grps[left]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work.
|
||||
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||
var event T
|
||||
return SubscribeTo(broker, event.Type(), handler)
|
||||
}
|
||||
|
||||
// SubscribeTo subscribes to an event with the specified event type.
|
||||
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||
if broker.isClosed() {
|
||||
panic(errClosed)
|
||||
}
|
||||
|
||||
broker.mu.Lock()
|
||||
defer broker.mu.Unlock()
|
||||
|
||||
// Check if group already exists
|
||||
if existing := broker.findGroup(eventType); existing != nil {
|
||||
grp := groupOf[T](eventType, existing)
|
||||
sub := grp.Add(handler)
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||
sub := grp.Add(handler)
|
||||
|
||||
// Copy-on-write: insert new entry in sorted position
|
||||
old := broker.subs.Load()
|
||||
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||
return old.keys[i] >= eventType
|
||||
})
|
||||
|
||||
// Create new arrays with space for one more element
|
||||
newKeys := make([]uint32, len(old.keys)+1)
|
||||
newGrps := make([]any, len(old.grps)+1)
|
||||
|
||||
// Copy elements before insertion point
|
||||
copy(newKeys[:idx], old.keys[:idx])
|
||||
copy(newGrps[:idx], old.grps[:idx])
|
||||
|
||||
// Insert new element
|
||||
newKeys[idx] = eventType
|
||||
newGrps[idx] = grp
|
||||
|
||||
// Copy elements after insertion point
|
||||
copy(newKeys[idx+1:], old.keys[idx:])
|
||||
copy(newGrps[idx+1:], old.grps[idx:])
|
||||
|
||||
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||
broker.subs.Store(newReg)
|
||||
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish writes an event into the dispatcher
|
||||
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||
eventType := ev.Type()
|
||||
if sub := broker.findGroup(eventType); sub != nil {
|
||||
group := groupOf[T](eventType, sub)
|
||||
group.Broadcast(ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Count counts the number of subscribers, this is for testing only.
|
||||
func (d *Dispatcher) count(eventType uint32) int {
|
||||
if group := d.findGroup(eventType); group != nil {
|
||||
return group.(interface{ Count() int }).Count()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupOf casts the subscriber group to the specified generic type
|
||||
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||
if group, ok := subs.(*group[T]); ok {
|
||||
return group
|
||||
}
|
||||
|
||||
panic(errConflict[T](eventType, subs))
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber -------------------------------------
|
||||
|
||||
// consumer represents a consumer with a message queue
|
||||
type consumer[T Event] struct {
|
||||
queue []T // Current work queue
|
||||
stop bool // Stop signal
|
||||
}
|
||||
|
||||
// Listen listens to the event queue and processes events
|
||||
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||
pending := make([]T, 0, 128)
|
||||
|
||||
for {
|
||||
c.L.Lock()
|
||||
for len(s.queue) == 0 {
|
||||
switch {
|
||||
case s.stop:
|
||||
c.L.Unlock()
|
||||
return
|
||||
default:
|
||||
c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap buffers and reset the current queue
|
||||
temp := s.queue
|
||||
s.queue = pending[:0]
|
||||
pending = temp
|
||||
c.L.Unlock()
|
||||
|
||||
// Outside of the critical section, process the work
|
||||
for _, event := range pending {
|
||||
fn(event)
|
||||
}
|
||||
|
||||
// Notify potential publishers waiting due to backpressure
|
||||
c.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber Group -------------------------------------
|
||||
|
||||
// group represents a consumer group
|
||||
type group[T Event] struct {
|
||||
cond *sync.Cond
|
||||
subs []*consumer[T]
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
maxLen int // Current maximum queue length across all consumers
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all consumers
|
||||
func (s *group[T]) Broadcast(ev T) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Calculate current maximum queue length
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
|
||||
// Backpressure: wait if queues are full
|
||||
for s.maxLen >= s.maxQueue {
|
||||
s.cond.Wait()
|
||||
|
||||
// Recalculate after wakeup
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add event to all queues and track new maximum
|
||||
newMax := 0
|
||||
for _, sub := range s.subs {
|
||||
sub.queue = append(sub.queue, ev)
|
||||
if len(sub.queue) > newMax {
|
||||
newMax = len(sub.queue)
|
||||
}
|
||||
}
|
||||
s.maxLen = newMax
|
||||
s.cond.Broadcast() // Wake consumers
|
||||
}
|
||||
|
||||
// Add adds a subscriber to the list
|
||||
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||
sub := &consumer[T]{
|
||||
queue: make([]T, 0, 64),
|
||||
}
|
||||
|
||||
// Add the consumer to the list of active consumers
|
||||
s.cond.L.Lock()
|
||||
s.subs = append(s.subs, sub)
|
||||
s.cond.L.Unlock()
|
||||
|
||||
// Start listening
|
||||
go sub.Listen(s.cond, handler)
|
||||
return sub
|
||||
}
|
||||
|
||||
// Del removes a subscriber from the list
|
||||
func (s *group[T]) Del(sub *consumer[T]) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Search and remove the subscriber
|
||||
sub.stop = true
|
||||
for i, v := range s.subs {
|
||||
if v == sub {
|
||||
copy(s.subs[i:], s.subs[i+1:])
|
||||
s.subs = s.subs[:len(s.subs)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Debugging -------------------------------------
|
||||
|
||||
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||
|
||||
// Count returns the number of subscribers in this group
|
||||
func (s *group[T]) Count() int {
|
||||
return len(s.subs)
|
||||
}
|
||||
|
||||
// String returns string representation of the type
|
||||
func (s *group[T]) String() string {
|
||||
typ := reflect.TypeOf(s).String()
|
||||
idx := strings.LastIndex(typ, "/")
|
||||
typ = typ[idx+1 : len(typ)-1]
|
||||
return typ
|
||||
}
|
||||
|
||||
// errConflict returns a conflict message
|
||||
func errConflict[T any](eventType uint32, existing any) string {
|
||||
var want T
|
||||
return fmt.Sprintf(
|
||||
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||
want, existing, eventType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe, must be received in order
|
||||
var count int64
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(3)
|
||||
Publish(d, MyEvent1{Number: 1})
|
||||
Publish(d, MyEvent1{Number: 2})
|
||||
Publish(d, MyEvent1{Number: 3})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||
// Nothing
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||
unsubscribe()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
}
|
||||
|
||||
func TestConcurrent(t *testing.T) {
|
||||
const max = 1000000
|
||||
var count int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
d := NewDispatcher()
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
if current := atomic.AddInt64(&count, 1); current == max {
|
||||
wg.Done()
|
||||
}
|
||||
})()
|
||||
|
||||
// Asynchronously publish
|
||||
go func() {
|
||||
for i := 0; i < max; i++ {
|
||||
Publish(d, MyEvent1{})
|
||||
}
|
||||
}()
|
||||
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
// Subscriber that does nothing
|
||||
})()
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, max, int(count))
|
||||
}
|
||||
|
||||
func TestSubscribeDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublishDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
Publish(d, MyEvent1{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDispatcher(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||
|
||||
assert.NoError(t, d.Close())
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix(t *testing.T) {
|
||||
const amount = 1000
|
||||
for _, subs := range []int{1, 10, 100} {
|
||||
for _, topics := range []int{1, 10} {
|
||||
expected := subs * topics * amount
|
||||
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(expected)
|
||||
|
||||
d := NewDispatcher()
|
||||
for i := 0; i < subs; i++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||
count.Add(1)
|
||||
wg.Done()
|
||||
})()
|
||||
}
|
||||
}
|
||||
|
||||
for n := 0; n < amount; n++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
go Publish(d, MyEvent3{ID: id})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, expected, int(count.Load()))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||
// This test specifically targets the race condition that occurs when multiple
|
||||
// goroutines try to subscribe to different event types simultaneously.
|
||||
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||
|
||||
const numGoroutines = 100
|
||||
const numEventTypes = 50
|
||||
|
||||
d := NewDispatcher()
|
||||
defer d.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var receivedCount int64
|
||||
var subscribedTypes sync.Map // Thread-safe map
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Start multiple goroutines that subscribe to different event types concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine subscribes to a unique event type
|
||||
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||
|
||||
// Subscribe to the event type
|
||||
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||
atomic.AddInt64(&receivedCount, 1)
|
||||
})
|
||||
|
||||
// Record that this type was subscribed
|
||||
subscribedTypes.Store(eventType, true)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all subscriptions to complete
|
||||
wg.Wait()
|
||||
|
||||
// Count the number of unique event types subscribed
|
||||
expectedTypes := 0
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
expectedTypes++
|
||||
return true
|
||||
})
|
||||
|
||||
// Small delay to ensure all subscriptions are fully processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish events to each subscribed type
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
eventType := key.(uint32)
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
return true
|
||||
})
|
||||
|
||||
// Wait for all events to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify that we received at least the expected number of events
|
||||
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||
received := atomic.LoadInt64(&receivedCount)
|
||||
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||
|
||||
// Verify that we have the expected number of unique event types
|
||||
assert.Equal(t, numEventTypes, expectedTypes,
|
||||
"Should have exactly %d unique event types", numEventTypes)
|
||||
}
|
||||
|
||||
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
|
||||
// Test concurrent subscriptions to the same event type
|
||||
t.Run("SameEventType", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var handlerCount int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||
atomic.AddInt64(&handlerCount, 1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all handlers were registered by publishing an event
|
||||
atomic.StoreInt64(&handlerCount, 0)
|
||||
Publish(d, MyEvent1{})
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||
"Not all handlers were registered due to race condition")
|
||||
})
|
||||
|
||||
// Test concurrent subscriptions to different event types
|
||||
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
receivedEvents := make(map[uint32]*int64)
|
||||
|
||||
// Create multiple event types and subscribe concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
eventType := uint32(100 + i)
|
||||
counter := new(int64)
|
||||
receivedEvents[eventType] = counter
|
||||
|
||||
wg.Add(1)
|
||||
go func(et uint32, cnt *int64) {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||
atomic.AddInt64(cnt, 1)
|
||||
})
|
||||
}(eventType, counter)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Publish events to all types
|
||||
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
}
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all event types received their events
|
||||
for eventType, counter := range receivedEvents {
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||
"Event type %d did not receive its event", eventType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackpressure(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
d.maxQueue = 10
|
||||
|
||||
var processedCount int64
|
||||
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||
atomic.AddInt64(&processedCount, 1)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
const eventsToPublish = 1000
|
||||
for i := 0; i < eventsToPublish; i++ {
|
||||
Publish(d, MyEvent3{ID: 0x200})
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all events were eventually processed
|
||||
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||
}
|
||||
|
||||
// ------------------------------------- Test Events -------------------------------------
|
||||
|
||||
const (
|
||||
TypeEvent1 = 0x1
|
||||
TypeEvent2 = 0x2
|
||||
)
|
||||
|
||||
type MyEvent1 struct {
|
||||
Number int
|
||||
}
|
||||
|
||||
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||
|
||||
type MyEvent2 struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||
|
||||
type MyEvent3 struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||
@@ -1,9 +1,14 @@
|
||||
module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -15,12 +20,10 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
@@ -29,12 +32,14 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -15,6 +17,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
@@ -23,11 +27,13 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
@@ -57,6 +63,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -64,22 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
+188
-10
@@ -1,47 +1,225 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/mostlygeek/llama-swap/proxy/configwatcher"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("version: v%s (%s), built at %s\n", version, commit, date)
|
||||
fmt.Printf("version: %s (%s), built at %s\n", version, commit, date)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(conf.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Reload signals (SIGHUP on POSIX, none on Windows — Windows does not
|
||||
// deliver SIGHUP). Always wired up so `kill -HUP` works regardless of
|
||||
// --watch-config.
|
||||
reloadChan := make(chan os.Signal, 1)
|
||||
if runtime.GOOS != "windows" {
|
||||
signal.Notify(reloadChan, syscall.SIGHUP)
|
||||
}
|
||||
|
||||
// Context that bounds the lifetime of background watcher goroutines.
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||
|
||||
// Listen for ConfigFileChangedEvent unconditionally so SIGHUP and the
|
||||
// poll-based watcher both feed the same debounced reload pipeline. The
|
||||
// UI also listens for the matching ReloadingStateEnd emitted from
|
||||
// reloadProxyManager.
|
||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||
debouncedReload()
|
||||
}
|
||||
})()
|
||||
|
||||
// SIGHUP (or platform-equivalent) → reload. Back-to-back signals collapse
|
||||
// to one reload via the debounce window, which is the desired behavior.
|
||||
go func() {
|
||||
for range reloadChan {
|
||||
fmt.Println("Received reload signal, reloading configuration")
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if *watchConfig {
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("Watching configuration for changes (poll-based, 2s interval)")
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
},
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown on signal
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
watcherCancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
}()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
func debounce(interval time.Duration, f func()) func() {
|
||||
var timer *time.Timer
|
||||
return func() {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.AfterFunc(interval, f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
ui_dist/*
|
||||
@@ -1,96 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
// Remove trailing backslashes
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
|
||||
// Split the command into arguments
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
@@ -0,0 +1,820 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// swap matrix: solver-based alternative to groups
|
||||
Matrix *MatrixConfig `yaml:"matrix"`
|
||||
|
||||
// populated during validation when matrix is configured
|
||||
ExpandedSets []ExpandedSet `yaml:"-"`
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
}
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
//go:build !windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
# comment 1
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
// Test the default values are automatically set for global, model and group configurations
|
||||
// after loading the configuration
|
||||
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadPosix(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
hooks:
|
||||
on_startup:
|
||||
preload: ["model1", "model2"]
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
name: "Model 1"
|
||||
description: "This is model 1"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
Preload: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,255 @@
|
||||
//go:build windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// does not support single quoted strings like in config_posix_test.go
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
|
||||
-a "double quotes" \
|
||||
-s
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadWindows(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
defaultTimeout := TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
}
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
Timeouts: defaultTimeout,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
|
||||
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||
// Useful with aliases: a single loaded model can behave differently depending on
|
||||
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||
// Protected params (like "model") cannot be set.
|
||||
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||
// with protected params removed and keys sorted for consistent iteration order.
|
||||
// Returns nil if the ID has no entry or all its params are protected.
|
||||
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||
if len(f.SetParamsByID) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
params, found := f.SetParamsByID[requestedModelID]
|
||||
if !found || len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]any, len(params))
|
||||
keys := make([]string, 0, len(params))
|
||||
for key, value := range params {
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, keys
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParamsByID map[string]map[string]any
|
||||
requestedModelID string
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty SetParamsByID returns nil",
|
||||
setParamsByID: nil,
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map returns nil",
|
||||
setParamsByID: map[string]map[string]any{},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "non-matching model ID returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model2": {"temperature": 0.9},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "matching model ID returns correct params",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||
"model2": {"temperature": 0.5},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected param model is filtered out",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "keys are sorted",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||
},
|
||||
{
|
||||
name: "alias style key lookup",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1:high": {"reasoning_effort": "high"},
|
||||
"model1:low": {"reasoning_effort": "low"},
|
||||
},
|
||||
requestedModelID: "model1:high",
|
||||
wantParams: map[string]any{
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
wantKeys: []string{"reasoning_effort"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams)
|
||||
assert.Nil(t, gotKeys)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||
assert.Equal(t, tt.wantParams, gotParams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test macro-in-macro basic substitution
|
||||
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${B}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test LIFO substitution order with 3+ macro levels
|
||||
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: load ${full}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test MODEL_ID in global macro used by model
|
||||
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: ${podman-llama} -m model.gguf
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||
}
|
||||
|
||||
// Test model macro overrides global macro in substitution
|
||||
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: echo ${msg}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test self-reference detection error
|
||||
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${recursive}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "recursive")
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test macro substitution in name and description fields
|
||||
func TestConfig_MacroInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"VARIANT": "Q4_K_M"
|
||||
"FAMILY": "llama"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "${FAMILY} ${VARIANT}"
|
||||
description: "A ${FAMILY} model in ${VARIANT} format"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "llama Q4_K_M", config.Models["my-model"].Name)
|
||||
assert.Equal(t, "A llama model in Q4_K_M format", config.Models["my-model"].Description)
|
||||
}
|
||||
|
||||
// Test MODEL_ID macro in name and description fields
|
||||
func TestConfig_ModelIDInNameAndDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
llama-3b:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model: ${MODEL_ID}"
|
||||
description: "Running ${MODEL_ID}"
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Model: llama-3b", config.Models["llama-3b"].Name)
|
||||
assert.Equal(t, "Running llama-3b", config.Models["llama-3b"].Description)
|
||||
}
|
||||
|
||||
// Test unknown macro in name or description returns an error
|
||||
func TestConfig_UnknownMacroInNameDescription(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
models:
|
||||
test:
|
||||
cmd: echo ok
|
||||
proxy: http://localhost:8080
|
||||
name: "Model ${UNDEFINED}"
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${A}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var varKeyPattern = regexp.MustCompile(`^[a-zA-Z0-9]{1,8}$`)
|
||||
|
||||
// MatrixConfig represents the swap matrix configuration block.
|
||||
type MatrixConfig struct {
|
||||
Var map[string]string `yaml:"vars"`
|
||||
EvictCosts map[string]int `yaml:"evict_costs"`
|
||||
Sets OrderedSets `yaml:"sets"`
|
||||
}
|
||||
|
||||
// SetEntry is a single named set with its DSL expression.
|
||||
type SetEntry struct {
|
||||
Name string
|
||||
DSL string
|
||||
}
|
||||
|
||||
// OrderedSets preserves YAML definition order of sets (used for tie-breaking).
|
||||
type OrderedSets []SetEntry
|
||||
|
||||
func (os *OrderedSets) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("sets must be a mapping")
|
||||
}
|
||||
|
||||
entries := make([]SetEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode set name: %w", err)
|
||||
}
|
||||
|
||||
var dsl string
|
||||
if err := valueNode.Decode(&dsl); err != nil {
|
||||
return fmt.Errorf("failed to decode DSL for set %q: %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, SetEntry{Name: name, DSL: dsl})
|
||||
}
|
||||
|
||||
*os = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpandedSet is one valid combination of concurrent models (real model names).
|
||||
type ExpandedSet struct {
|
||||
SetName string
|
||||
DSL string
|
||||
Models []string // real model names, sorted
|
||||
}
|
||||
|
||||
// ValidateMatrix validates the matrix config and returns all expanded sets.
|
||||
func ValidateMatrix(matrix MatrixConfig, models map[string]ModelConfig) ([]ExpandedSet, error) {
|
||||
if len(matrix.Sets) == 0 {
|
||||
return nil, fmt.Errorf("matrix must define at least one set")
|
||||
}
|
||||
|
||||
if len(matrix.Var) == 0 {
|
||||
return nil, fmt.Errorf("matrix must define at least one var")
|
||||
}
|
||||
|
||||
// Validate var entries
|
||||
if matrix.Var != nil {
|
||||
for id, modelName := range matrix.Var {
|
||||
if !varKeyPattern.MatchString(id) {
|
||||
return nil, fmt.Errorf("var key %q must be alphanumeric and 1-8 characters", id)
|
||||
}
|
||||
if _, exists := models[modelName]; !exists {
|
||||
return nil, fmt.Errorf("var key %q references unknown model %q", id, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate evict_costs
|
||||
if matrix.EvictCosts != nil {
|
||||
for key, cost := range matrix.EvictCosts {
|
||||
if cost <= 0 {
|
||||
return nil, fmt.Errorf("evict_cost for %q must be a positive integer, got %d", key, cost)
|
||||
}
|
||||
if _, ok := matrix.Var[key]; !ok {
|
||||
return nil, fmt.Errorf("evict_costs: unknown var ID %q", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build dependency graph for +ref topological sort
|
||||
setNames := make(map[string]bool)
|
||||
for _, entry := range matrix.Sets {
|
||||
setNames[entry.Name] = true
|
||||
}
|
||||
|
||||
deps := make(map[string][]string) // setName -> set names it depends on
|
||||
for _, entry := range matrix.Sets {
|
||||
refs, err := extractRefs(entry.DSL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set %q: %w", entry.Name, err)
|
||||
}
|
||||
for _, ref := range refs {
|
||||
if !setNames[ref] {
|
||||
return nil, fmt.Errorf("set %q references undefined set %q", entry.Name, ref)
|
||||
}
|
||||
}
|
||||
deps[entry.Name] = refs
|
||||
}
|
||||
|
||||
// Topological sort with cycle detection
|
||||
order, err := topologicalSort(matrix.Sets, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Expand sets in topological order
|
||||
resolvedRefs := make(map[string][][]string) // set name -> expanded alias-level combos
|
||||
var allExpanded []ExpandedSet
|
||||
totalCombinations := 0
|
||||
|
||||
// Build ordered map for efficient lookup
|
||||
setDSL := make(map[string]string)
|
||||
for _, entry := range matrix.Sets {
|
||||
setDSL[entry.Name] = entry.DSL
|
||||
}
|
||||
|
||||
for _, name := range order {
|
||||
dsl := setDSL[name]
|
||||
combos, err := ParseAndExpandDSL(dsl, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set %q: %w", name, err)
|
||||
}
|
||||
|
||||
resolvedRefs[name] = combos
|
||||
|
||||
// Resolve var IDs to real model names
|
||||
for _, combo := range combos {
|
||||
resolved := make([]string, len(combo))
|
||||
for i, ident := range combo {
|
||||
realName, ok := matrix.Var[ident]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("set %q: unknown var ID %q", name, ident)
|
||||
}
|
||||
resolved[i] = realName
|
||||
}
|
||||
sort.Strings(resolved)
|
||||
allExpanded = append(allExpanded, ExpandedSet{
|
||||
SetName: name,
|
||||
DSL: dsl,
|
||||
Models: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
totalCombinations += len(combos)
|
||||
if totalCombinations > maxDSLExpansions {
|
||||
return nil, fmt.Errorf("total expanded combinations (%d) exceed limit of %d", totalCombinations, maxDSLExpansions)
|
||||
}
|
||||
}
|
||||
|
||||
return allExpanded, nil
|
||||
}
|
||||
|
||||
// topologicalSort returns set names in dependency order.
|
||||
// Returns an error if a cycle is detected.
|
||||
func topologicalSort(sets OrderedSets, deps map[string][]string) ([]string, error) {
|
||||
// States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state := make(map[string]int)
|
||||
var order []string
|
||||
|
||||
var visit func(name string) error
|
||||
visit = func(name string) error {
|
||||
switch state[name] {
|
||||
case 1:
|
||||
return fmt.Errorf("circular reference detected involving set %q", name)
|
||||
case 2:
|
||||
return nil
|
||||
}
|
||||
state[name] = 1
|
||||
|
||||
for _, dep := range deps[name] {
|
||||
if err := visit(dep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
state[name] = 2
|
||||
order = append(order, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Visit in definition order for deterministic output
|
||||
for _, entry := range sets {
|
||||
if state[entry.Name] == 0 {
|
||||
if err := visit(entry.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// ResolvedEvictCosts returns a map of real model name -> evict cost,
|
||||
// resolving var IDs. Models not listed default to 1.
|
||||
func (m *MatrixConfig) ResolvedEvictCosts() map[string]int {
|
||||
costs := make(map[string]int)
|
||||
if m.EvictCosts == nil {
|
||||
return costs
|
||||
}
|
||||
for key, cost := range m.EvictCosts {
|
||||
// Resolve var ID if present
|
||||
if realName, ok := m.Var[key]; ok {
|
||||
costs[realName] = cost
|
||||
} else {
|
||||
costs[key] = cost
|
||||
}
|
||||
}
|
||||
return costs
|
||||
}
|
||||
@@ -0,0 +1,376 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const maxDSLExpansions = 1000
|
||||
|
||||
// Token types for the DSL lexer
|
||||
type tokenType int
|
||||
|
||||
const (
|
||||
tokIdent tokenType = iota // model alias or name
|
||||
tokAnd // &
|
||||
tokOr // |
|
||||
tokLParen // (
|
||||
tokRParen // )
|
||||
tokRef // +setName
|
||||
tokEOF
|
||||
)
|
||||
|
||||
type token struct {
|
||||
typ tokenType
|
||||
val string
|
||||
}
|
||||
|
||||
// tokenize splits a DSL string into tokens.
|
||||
func tokenize(input string) ([]token, error) {
|
||||
var tokens []token
|
||||
i := 0
|
||||
runes := []rune(input)
|
||||
|
||||
for i < len(runes) {
|
||||
ch := runes[i]
|
||||
|
||||
// skip whitespace
|
||||
if unicode.IsSpace(ch) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '&':
|
||||
tokens = append(tokens, token{tokAnd, "&"})
|
||||
i++
|
||||
case '|':
|
||||
tokens = append(tokens, token{tokOr, "|"})
|
||||
i++
|
||||
case '(':
|
||||
tokens = append(tokens, token{tokLParen, "("})
|
||||
i++
|
||||
case ')':
|
||||
tokens = append(tokens, token{tokRParen, ")"})
|
||||
i++
|
||||
case '+':
|
||||
// +ref: read the identifier that follows
|
||||
i++
|
||||
start := i
|
||||
for i < len(runes) && isIdentChar(runes[i]) {
|
||||
i++
|
||||
}
|
||||
if i == start {
|
||||
return nil, fmt.Errorf("expected set name after '+' at position %d", start)
|
||||
}
|
||||
tokens = append(tokens, token{tokRef, string(runes[start:i])})
|
||||
default:
|
||||
if isIdentChar(ch) {
|
||||
start := i
|
||||
for i < len(runes) && isIdentChar(runes[i]) {
|
||||
i++
|
||||
}
|
||||
tokens = append(tokens, token{tokIdent, string(runes[start:i])})
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected character %q at position %d", ch, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokens = append(tokens, token{tokEOF, ""})
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func isIdentChar(ch rune) bool {
|
||||
return unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '-' || ch == '.'
|
||||
}
|
||||
|
||||
// AST node types
|
||||
type dslNode interface {
|
||||
dslNode()
|
||||
}
|
||||
|
||||
type andNode struct {
|
||||
children []dslNode
|
||||
}
|
||||
|
||||
type orNode struct {
|
||||
children []dslNode
|
||||
}
|
||||
|
||||
type leafNode struct {
|
||||
name string
|
||||
}
|
||||
|
||||
type refNode struct {
|
||||
setName string
|
||||
}
|
||||
|
||||
func (andNode) dslNode() {}
|
||||
func (orNode) dslNode() {}
|
||||
func (leafNode) dslNode() {}
|
||||
func (refNode) dslNode() {}
|
||||
|
||||
// parser holds state for recursive-descent parsing.
|
||||
type parser struct {
|
||||
tokens []token
|
||||
pos int
|
||||
}
|
||||
|
||||
func (p *parser) peek() token {
|
||||
if p.pos < len(p.tokens) {
|
||||
return p.tokens[p.pos]
|
||||
}
|
||||
return token{tokEOF, ""}
|
||||
}
|
||||
|
||||
func (p *parser) next() token {
|
||||
t := p.peek()
|
||||
if t.typ != tokEOF {
|
||||
p.pos++
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (p *parser) expect(typ tokenType) (token, error) {
|
||||
t := p.next()
|
||||
if t.typ != typ {
|
||||
return t, fmt.Errorf("expected token type %d, got %q", typ, t.val)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Grammar:
|
||||
//
|
||||
// expr = andExpr
|
||||
// andExpr = orExpr ('&' orExpr)*
|
||||
// orExpr = atom ('|' atom)*
|
||||
// atom = ident | '+' ident | '(' expr ')'
|
||||
//
|
||||
// & binds tighter than |, so "a | b & c" means "a | (b & c)"
|
||||
func parse(tokens []token) (dslNode, error) {
|
||||
p := &parser{tokens: tokens}
|
||||
node, err := p.parseExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.peek().typ != tokEOF {
|
||||
return nil, fmt.Errorf("unexpected token %q after expression", p.peek().val)
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseExpr() (dslNode, error) {
|
||||
return p.parseOrExpr()
|
||||
}
|
||||
|
||||
func (p *parser) parseOrExpr() (dslNode, error) {
|
||||
left, err := p.parseAndExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.peek().typ == tokOr {
|
||||
children := []dslNode{left}
|
||||
for p.peek().typ == tokOr {
|
||||
p.next() // consume |
|
||||
right, err := p.parseAndExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
children = append(children, right)
|
||||
}
|
||||
return orNode{children: children}, nil
|
||||
}
|
||||
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAndExpr() (dslNode, error) {
|
||||
left, err := p.parseAtom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.peek().typ == tokAnd {
|
||||
children := []dslNode{left}
|
||||
for p.peek().typ == tokAnd {
|
||||
p.next() // consume &
|
||||
right, err := p.parseAtom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
children = append(children, right)
|
||||
}
|
||||
return andNode{children: children}, nil
|
||||
}
|
||||
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAtom() (dslNode, error) {
|
||||
t := p.peek()
|
||||
|
||||
switch t.typ {
|
||||
case tokIdent:
|
||||
p.next()
|
||||
return leafNode{name: t.val}, nil
|
||||
|
||||
case tokRef:
|
||||
p.next()
|
||||
return refNode{setName: t.val}, nil
|
||||
|
||||
case tokLParen:
|
||||
p.next() // consume (
|
||||
node, err := p.parseExpr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := p.expect(tokRParen); err != nil {
|
||||
return nil, fmt.Errorf("missing closing parenthesis")
|
||||
}
|
||||
return node, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected token %q", t.val)
|
||||
}
|
||||
}
|
||||
|
||||
// expand walks the AST and produces all combinations.
|
||||
// resolvedRefs contains previously expanded sets for +ref resolution.
|
||||
func expand(node dslNode, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||
switch n := node.(type) {
|
||||
case leafNode:
|
||||
return [][]string{{n.name}}, nil
|
||||
|
||||
case refNode:
|
||||
expanded, ok := resolvedRefs[n.setName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown set reference +%s", n.setName)
|
||||
}
|
||||
// Return a copy
|
||||
result := make([][]string, len(expanded))
|
||||
for i, combo := range expanded {
|
||||
result[i] = make([]string, len(combo))
|
||||
copy(result[i], combo)
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case orNode:
|
||||
// Union of all children's expansions
|
||||
var result [][]string
|
||||
for _, child := range n.children {
|
||||
childResult, err := expand(child, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, childResult...)
|
||||
if len(result) > maxDSLExpansions {
|
||||
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", maxDSLExpansions)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case andNode:
|
||||
// Cartesian product across children
|
||||
result := [][]string{{}} // start with one empty combo
|
||||
for _, child := range n.children {
|
||||
childResult, err := expand(child, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = cartesianProduct(result, childResult, maxDSLExpansions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown node type %T", node)
|
||||
}
|
||||
}
|
||||
|
||||
// cartesianProduct computes the cartesian product of two sets of combinations.
|
||||
// It returns an error if the product would exceed cap.
|
||||
func cartesianProduct(left, right [][]string, cap int) ([][]string, error) {
|
||||
if int64(len(left))*int64(len(right)) > int64(cap) {
|
||||
return nil, fmt.Errorf("DSL expansion exceeded %d combinations", cap)
|
||||
}
|
||||
result := make([][]string, 0, len(left)*len(right))
|
||||
for _, l := range left {
|
||||
for _, r := range right {
|
||||
combo := make([]string, 0, len(l)+len(r))
|
||||
combo = append(combo, l...)
|
||||
combo = append(combo, r...)
|
||||
result = append(result, combo)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ParseAndExpandDSL tokenizes, parses, and expands a DSL string.
|
||||
// resolvedRefs contains previously expanded sets for +ref inlining.
|
||||
func ParseAndExpandDSL(dsl string, resolvedRefs map[string][][]string) ([][]string, error) {
|
||||
dsl = strings.TrimSpace(dsl)
|
||||
if dsl == "" {
|
||||
return nil, fmt.Errorf("empty DSL expression")
|
||||
}
|
||||
|
||||
tokens, err := tokenize(dsl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenize: %w", err)
|
||||
}
|
||||
|
||||
tree, err := parse(tokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse: %w", err)
|
||||
}
|
||||
|
||||
result, err := expand(tree, resolvedRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Deduplicate models within each combination and sort for consistency
|
||||
for i, combo := range result {
|
||||
result[i] = dedupAndSort(combo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// dedupAndSort removes duplicate entries and sorts alphabetically.
|
||||
func dedupAndSort(items []string) []string {
|
||||
seen := make(map[string]bool, len(items))
|
||||
var unique []string
|
||||
for _, item := range items {
|
||||
if !seen[item] {
|
||||
seen[item] = true
|
||||
unique = append(unique, item)
|
||||
}
|
||||
}
|
||||
sort.Strings(unique)
|
||||
return unique
|
||||
}
|
||||
|
||||
// extractRefs scans a DSL string for +ref tokens without full parsing.
|
||||
// Used for building the dependency graph for topological sorting.
|
||||
func extractRefs(dsl string) ([]string, error) {
|
||||
tokens, err := tokenize(dsl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var refs []string
|
||||
seen := make(map[string]bool)
|
||||
for _, t := range tokens {
|
||||
if t.typ == tokRef && !seen[t.val] {
|
||||
seen[t.val] = true
|
||||
refs = append(refs, t.val)
|
||||
}
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDSL_Tokenize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect []token
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "single identifier",
|
||||
input: "abc",
|
||||
expect: []token{
|
||||
{tokIdent, "abc"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "identifier with hyphens and dots",
|
||||
input: "model-name.v2",
|
||||
expect: []token{
|
||||
{tokIdent, "model-name.v2"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "and expression",
|
||||
input: "a & b",
|
||||
expect: []token{
|
||||
{tokIdent, "a"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "b"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "or expression",
|
||||
input: "a | b",
|
||||
expect: []token{
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parentheses",
|
||||
input: "(a | b) & c",
|
||||
expect: []token{
|
||||
{tokLParen, "("},
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokRParen, ")"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "c"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ref token",
|
||||
input: "+llms & v",
|
||||
expect: []token{
|
||||
{tokRef, "llms"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "v"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no whitespace",
|
||||
input: "(a|b)&c",
|
||||
expect: []token{
|
||||
{tokLParen, "("},
|
||||
{tokIdent, "a"},
|
||||
{tokOr, "|"},
|
||||
{tokIdent, "b"},
|
||||
{tokRParen, ")"},
|
||||
{tokAnd, "&"},
|
||||
{tokIdent, "c"},
|
||||
{tokEOF, ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty ref",
|
||||
input: "+",
|
||||
errMsg: "expected set name after '+'",
|
||||
},
|
||||
{
|
||||
name: "invalid character",
|
||||
input: "a @ b",
|
||||
errMsg: "unexpected character",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens, err := tokenize(tt.input)
|
||||
if tt.errMsg != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, tokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSL_ParseAndExpand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dsl string
|
||||
refs map[string][][]string
|
||||
expect [][]string
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "single model",
|
||||
dsl: "L",
|
||||
expect: [][]string{{"L"}},
|
||||
},
|
||||
{
|
||||
name: "two models with AND",
|
||||
dsl: "a & b",
|
||||
expect: [][]string{{"a", "b"}},
|
||||
},
|
||||
{
|
||||
name: "two models with OR",
|
||||
dsl: "a | b",
|
||||
expect: [][]string{{"a"}, {"b"}},
|
||||
},
|
||||
{
|
||||
name: "three models with OR",
|
||||
dsl: "a | b | c",
|
||||
expect: [][]string{{"a"}, {"b"}, {"c"}},
|
||||
},
|
||||
{
|
||||
name: "cartesian product (a|b) & (c|d)",
|
||||
dsl: "(a | b) & (c | d)",
|
||||
expect: [][]string{
|
||||
{"a", "c"},
|
||||
{"a", "d"},
|
||||
{"b", "c"},
|
||||
{"b", "d"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "three-way AND",
|
||||
dsl: "a & b & c",
|
||||
expect: [][]string{
|
||||
{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "(g | q | m) & v",
|
||||
dsl: "(g | q | m) & v",
|
||||
expect: [][]string{
|
||||
{"g", "v"},
|
||||
{"q", "v"},
|
||||
{"m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "(g | q) & v & e",
|
||||
dsl: "(g | q) & v & e",
|
||||
expect: [][]string{
|
||||
{"e", "g", "v"},
|
||||
{"e", "q", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "precedence: a | b & c means a | (b & c)",
|
||||
dsl: "a | b & c",
|
||||
expect: [][]string{
|
||||
{"a"},
|
||||
{"b", "c"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "+ref inlining",
|
||||
dsl: "+llms & v",
|
||||
refs: map[string][][]string{
|
||||
"llms": {{"g"}, {"q"}, {"m"}},
|
||||
},
|
||||
expect: [][]string{
|
||||
{"g", "v"},
|
||||
{"q", "v"},
|
||||
{"m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "+ref chained",
|
||||
dsl: "+with_tts & e",
|
||||
refs: map[string][][]string{
|
||||
"with_tts": {{"g", "v"}, {"q", "v"}, {"m", "v"}},
|
||||
},
|
||||
expect: [][]string{
|
||||
{"e", "g", "v"},
|
||||
{"e", "q", "v"},
|
||||
{"e", "m", "v"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dedup within combination",
|
||||
dsl: "a & a",
|
||||
expect: [][]string{
|
||||
{"a"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty expression",
|
||||
dsl: "",
|
||||
errMsg: "empty DSL expression",
|
||||
},
|
||||
{
|
||||
name: "unmatched open paren",
|
||||
dsl: "(a | b",
|
||||
errMsg: "missing closing parenthesis",
|
||||
},
|
||||
{
|
||||
name: "unmatched close paren",
|
||||
dsl: "a | b)",
|
||||
errMsg: "unexpected token",
|
||||
},
|
||||
{
|
||||
name: "unknown ref",
|
||||
dsl: "+unknown",
|
||||
errMsg: "unknown set reference +unknown",
|
||||
},
|
||||
{
|
||||
name: "empty parens",
|
||||
dsl: "()",
|
||||
errMsg: "unexpected token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
refs := tt.refs
|
||||
if refs == nil {
|
||||
refs = map[string][][]string{}
|
||||
}
|
||||
result, err := ParseAndExpandDSL(tt.dsl, refs)
|
||||
if tt.errMsg != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSL_ExpansionCap(t *testing.T) {
|
||||
// Build an expression that would exceed 1000 combinations:
|
||||
// (a1|a2|...|a32) & (b1|b2|...|b32) = 1024 combos
|
||||
var aItems, bItems []string
|
||||
for i := 0; i < 32; i++ {
|
||||
aItems = append(aItems, fmt.Sprintf("a%d", i))
|
||||
bItems = append(bItems, fmt.Sprintf("b%d", i))
|
||||
}
|
||||
dsl := fmt.Sprintf("(%s) & (%s)",
|
||||
join(aItems, " | "),
|
||||
join(bItems, " | "),
|
||||
)
|
||||
_, err := ParseAndExpandDSL(dsl, map[string][][]string{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded")
|
||||
}
|
||||
|
||||
func TestDSL_ExtractRefs(t *testing.T) {
|
||||
refs, err := extractRefs("+llms & v & +other")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"llms", "other"}, refs)
|
||||
|
||||
refs, err = extractRefs("a & b")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, refs)
|
||||
}
|
||||
|
||||
func join(items []string, sep string) string {
|
||||
result := ""
|
||||
for i, item := range items {
|
||||
if i > 0 {
|
||||
result += sep
|
||||
}
|
||||
result += item
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeModels(names ...string) map[string]ModelConfig {
|
||||
m := make(map[string]ModelConfig)
|
||||
for _, name := range names {
|
||||
m[name] = ModelConfig{Cmd: "echo " + name}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestValidateMatrix_Basic(t *testing.T) {
|
||||
models := makeModels("gemma", "qwen", "mistral", "voxtral", "llama70B")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"q": "qwen",
|
||||
"m": "mistral",
|
||||
"v": "voxtral",
|
||||
"L": "llama70B",
|
||||
},
|
||||
EvictCosts: map[string]int{
|
||||
"L": 30,
|
||||
"v": 50,
|
||||
},
|
||||
Sets: OrderedSets{
|
||||
{Name: "standard", DSL: "(g | q | m) & v"},
|
||||
{Name: "full", DSL: "L"},
|
||||
},
|
||||
}
|
||||
|
||||
expanded, err := ValidateMatrix(matrix, models)
|
||||
require.NoError(t, err)
|
||||
|
||||
// standard expands to [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||
// full expands to [llama70B]
|
||||
assert.Len(t, expanded, 4)
|
||||
|
||||
assert.Equal(t, "standard", expanded[0].SetName)
|
||||
assert.Equal(t, []string{"gemma", "voxtral"}, expanded[0].Models)
|
||||
|
||||
assert.Equal(t, "standard", expanded[1].SetName)
|
||||
assert.Equal(t, []string{"qwen", "voxtral"}, expanded[1].Models)
|
||||
|
||||
assert.Equal(t, "standard", expanded[2].SetName)
|
||||
assert.Equal(t, []string{"mistral", "voxtral"}, expanded[2].Models)
|
||||
|
||||
assert.Equal(t, "full", expanded[3].SetName)
|
||||
assert.Equal(t, []string{"llama70B"}, expanded[3].Models)
|
||||
}
|
||||
|
||||
func TestValidateMatrix_WithRef(t *testing.T) {
|
||||
models := makeModels("gemma", "qwen", "mistral", "voxtral", "reranker")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"q": "qwen",
|
||||
"m": "mistral",
|
||||
"v": "voxtral",
|
||||
"e": "reranker",
|
||||
},
|
||||
Sets: OrderedSets{
|
||||
{Name: "llms", DSL: "g | q | m"},
|
||||
{Name: "with_tts", DSL: "+llms & v"},
|
||||
{Name: "mega", DSL: "+with_tts & e"},
|
||||
},
|
||||
}
|
||||
|
||||
expanded, err := ValidateMatrix(matrix, models)
|
||||
require.NoError(t, err)
|
||||
|
||||
// llms: [gemma], [qwen], [mistral]
|
||||
// with_tts: [gemma,voxtral], [qwen,voxtral], [mistral,voxtral]
|
||||
// mega: [gemma,reranker,voxtral], [qwen,reranker,voxtral], [mistral,reranker,voxtral]
|
||||
assert.Len(t, expanded, 9)
|
||||
|
||||
// Check mega entries
|
||||
megaEntries := filterBySetName(expanded, "mega")
|
||||
assert.Len(t, megaEntries, 3)
|
||||
assert.Equal(t, []string{"gemma", "reranker", "voxtral"}, megaEntries[0].Models)
|
||||
}
|
||||
|
||||
func TestValidateMatrix_MapIDRequired(t *testing.T) {
|
||||
// DSL cannot use real model names directly — must use var IDs
|
||||
models := makeModels("gemma", "voxtral")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "combo", DSL: "g & voxtral"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_InvalidAliasKey(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
alias string
|
||||
errMsg string
|
||||
}{
|
||||
{"too long", "abcdefghi", "alphanumeric and 1-8 characters"},
|
||||
{"has underscore", "a_b", "alphanumeric and 1-8 characters"},
|
||||
{"has hyphen", "a-b", "alphanumeric and 1-8 characters"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{tt.alias: "gemma"},
|
||||
Sets: OrderedSets{{Name: "s", DSL: tt.alias}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMatrix_AliasReferencesUnknownModel(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"x": "nonexistent"},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "x"}},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown model")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_EvictCostInvalid(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
t.Run("zero cost", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"g": 0},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "positive integer")
|
||||
})
|
||||
|
||||
t.Run("negative cost", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"g": -1},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "positive integer")
|
||||
})
|
||||
|
||||
t.Run("unknown var ID in evict_costs", func(t *testing.T) {
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
EvictCosts: map[string]int{"unknown": 5},
|
||||
Sets: OrderedSets{{Name: "s", DSL: "g"}},
|
||||
}
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateMatrix_CycleDetection(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "a", DSL: "+b"},
|
||||
{Name: "b", DSL: "+a"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "circular reference")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_UndefinedRefTarget(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "a", DSL: "+nonexistent"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "references undefined set")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_NoSets(t *testing.T) {
|
||||
_, err := ValidateMatrix(MatrixConfig{}, makeModels("gemma"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one set")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_UnknownMapIDInDSL(t *testing.T) {
|
||||
models := makeModels("gemma")
|
||||
|
||||
matrix := MatrixConfig{
|
||||
Var: map[string]string{"g": "gemma"},
|
||||
Sets: OrderedSets{
|
||||
{Name: "s", DSL: "g & nonexistent"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ValidateMatrix(matrix, models)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown var ID")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ResolvedEvictCosts(t *testing.T) {
|
||||
mc := &MatrixConfig{
|
||||
Var: map[string]string{
|
||||
"g": "gemma",
|
||||
"L": "llama70B",
|
||||
},
|
||||
EvictCosts: map[string]int{
|
||||
"L": 30,
|
||||
"g": 5,
|
||||
},
|
||||
}
|
||||
|
||||
costs := mc.ResolvedEvictCosts()
|
||||
assert.Equal(t, 30, costs["llama70B"])
|
||||
assert.Equal(t, 5, costs["gemma"])
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ConfigXOR(t *testing.T) {
|
||||
// groups and matrix both defined
|
||||
yaml := `
|
||||
models:
|
||||
model1:
|
||||
cmd: echo model1
|
||||
proxy: http://localhost:8080
|
||||
groups:
|
||||
group1:
|
||||
members:
|
||||
- model1
|
||||
matrix:
|
||||
sets:
|
||||
s: "model1"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot use both")
|
||||
}
|
||||
|
||||
func TestValidateMatrix_ConfigMatrixOnly(t *testing.T) {
|
||||
yaml := `
|
||||
models:
|
||||
gemma:
|
||||
cmd: echo gemma
|
||||
proxy: http://localhost:8080
|
||||
qwen:
|
||||
cmd: echo qwen
|
||||
proxy: http://localhost:8081
|
||||
matrix:
|
||||
vars:
|
||||
g: gemma
|
||||
q: qwen
|
||||
sets:
|
||||
combo: "g | q"
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(yaml))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Matrix)
|
||||
assert.Len(t, cfg.ExpandedSets, 2)
|
||||
// Groups should be empty when matrix is used
|
||||
assert.Empty(t, cfg.Groups)
|
||||
}
|
||||
|
||||
func filterBySetName(sets []ExpandedSet, name string) []ExpandedSet {
|
||||
var result []ExpandedSet
|
||||
for _, s := range sets {
|
||||
if s.SetName == name {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
// TimeoutsConfig holds timeout settings for proxy connections
|
||||
// 0 = no timeout
|
||||
type TimeoutsConfig struct {
|
||||
Connect int `yaml:"connect"`
|
||||
KeepAlive int `yaml:"keepalive"`
|
||||
ResponseHeader int `yaml:"responseHeader"`
|
||||
TLSHandshake int `yaml:"tlsHandshake"`
|
||||
ExpectContinue int `yaml:"expectContinue"`
|
||||
IdleConn int `yaml:"idleConn"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
|
||||
// Macros: see #264
|
||||
// Model level macros take precedence over the global macros
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// Metadata: see #264
|
||||
// Arbitrary metadata that can be exposed through the API
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
|
||||
// matches http.DefaultTransport
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 0,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal with the old field name for backwards compatibility
|
||||
if defaults.StripParams == "" {
|
||||
var legacy struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||
}
|
||||
defaults.StripParams = legacy.StripParams
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
# macros inserted and list is cleaned of duplicates and empty strings
|
||||
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
# check for strip_params (legacy field name) compatibility
|
||||
legacy:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
for modelId, modelConfig := range config.Models {
|
||||
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
// model has been removed
|
||||
// empty strings have been removed
|
||||
// duplicates have been removed
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelSendLoadingState(t *testing.T) {
|
||||
content := `
|
||||
sendLoadingState: true
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
sendLoadingState: false
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, config.SendLoadingState)
|
||||
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
|
||||
assert.False(t, *config.Models["model1"].SendLoadingState)
|
||||
}
|
||||
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Keys (other than the model's own ID) should be registered as aliases
|
||||
realName, found := cfg.RealModelName("model1:high")
|
||||
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
realName, found = cfg.RealModelName("model1:low")
|
||||
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
// Auto-aliases should also appear in modelConfig.Aliases
|
||||
aliases := cfg.Models["model1"].Aliases
|
||||
assert.Contains(t, aliases, "model1:high")
|
||||
assert.Contains(t, aliases, "model1:low")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
model2:
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: low
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "duplicate alias")
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
|
||||
// Timeout settings for proxy connections
|
||||
Timeouts TimeoutsConfig `yaml:"timeouts"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
|
||||
// mostly matches http.DefaultTransport but with a 60s ResponseHeader timeout
|
||||
// to match the pre PR #619 functionality
|
||||
Timeouts: TimeoutsConfig{
|
||||
Connect: 30,
|
||||
KeepAlive: 30,
|
||||
ResponseHeader: 60,
|
||||
TLSHandshake: 10,
|
||||
ExpectContinue: 1,
|
||||
IdleConn: 90,
|
||||
},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
// make make this shared between the different tests
|
||||
config := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "python model1.py",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "python model2.py",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2", "model-two"},
|
||||
Env: []string{"VAR3=value3", "VAR4=value4"},
|
||||
CheckEndpoint: "/status",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 10,
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, modelId, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, modelId, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, modelId, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, "", modelId)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
// Package configwatcher provides a simple cross-platform file watcher based
|
||||
// on os.Stat polling. It works correctly inside Docker containers where the
|
||||
// config file is bind-mounted as an individual file, and for k8s ConfigMap
|
||||
// projections (which present the file as a symlink to an atomically swapped
|
||||
// target) — both cases where inotify-based watchers are unreliable.
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DefaultInterval = 2 * time.Second
|
||||
|
||||
type Watcher struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
type snapshot struct {
|
||||
exists bool
|
||||
modTime time.Time
|
||||
size int64
|
||||
}
|
||||
|
||||
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||
// OnChange whenever the file's modification time or size changes, or when
|
||||
// the file reappears after being missing. The baseline poll establishes
|
||||
// initial state and does not fire OnChange.
|
||||
func (w *Watcher) Run(ctx context.Context) {
|
||||
interval := w.Interval
|
||||
if interval <= 0 {
|
||||
interval = DefaultInterval
|
||||
}
|
||||
|
||||
prev := stat(w.Path)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cur := stat(w.Path)
|
||||
if changed(prev, cur) && w.OnChange != nil {
|
||||
w.OnChange()
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stat(path string) snapshot {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Printf("configwatcher: stat %s: %v", path, err)
|
||||
}
|
||||
return snapshot{}
|
||||
}
|
||||
return snapshot{
|
||||
exists: true,
|
||||
modTime: fi.ModTime(),
|
||||
size: fi.Size(),
|
||||
}
|
||||
}
|
||||
|
||||
func changed(prev, cur snapshot) bool {
|
||||
// Present → missing: stay quiet (likely a transient rename-style write).
|
||||
// Missing → present: fire so we reload as soon as the file comes back.
|
||||
if !cur.exists {
|
||||
return false
|
||||
}
|
||||
if !prev.exists {
|
||||
return true
|
||||
}
|
||||
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testInterval = 25 * time.Millisecond
|
||||
|
||||
// startWatcher launches w.Run in a goroutine and returns a function that
|
||||
// cancels the context and waits for Run to return.
|
||||
func startWatcher(t *testing.T, w *Watcher) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("watcher did not stop within 2s of cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCount blocks until counter reaches want or timeout elapses.
|
||||
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if atomic.LoadInt64(counter) >= want {
|
||||
return true
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestWatcher_NoFireOnBaseline(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
time.Sleep(testInterval * 5)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||
}
|
||||
|
||||
func TestWatcher_DetectsModTimeChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
// Force a known baseline mtime.
|
||||
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||
require.NoError(t, os.Chtimes(path, base, base))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
// Let the baseline settle.
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Bump mtime well above the baseline so low-resolution filesystems still notice.
|
||||
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||
}
|
||||
|
||||
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
require.NoError(t, err)
|
||||
originalMtime := fi.ModTime()
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
|
||||
// Reset mtime back to the original so size is the only signal.
|
||||
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
|
||||
}
|
||||
|
||||
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
targetA := filepath.Join(dir, "targetA")
|
||||
targetB := filepath.Join(dir, "targetB")
|
||||
link := filepath.Join(dir, "config.yaml")
|
||||
|
||||
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
|
||||
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
|
||||
|
||||
if err := os.Symlink(targetA, link); err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skipf("symlink creation requires privilege on Windows: %v", err)
|
||||
}
|
||||
t.Fatalf("os.Symlink: %v", err)
|
||||
}
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: link,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
|
||||
// temp name, then rename over the existing one.
|
||||
tmpLink := filepath.Join(dir, "config.yaml.tmp")
|
||||
require.NoError(t, os.Symlink(targetB, tmpLink))
|
||||
require.NoError(t, os.Rename(tmpLink, link))
|
||||
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
|
||||
}
|
||||
|
||||
func TestWatcher_FileMissingThenReturns(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
var n int64
|
||||
stop := startWatcher(t, &Watcher{
|
||||
Path: path,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Remove(path))
|
||||
time.Sleep(testInterval * 3)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
|
||||
|
||||
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
|
||||
}
|
||||
|
||||
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||
|
||||
w := &Watcher{Path: path, Interval: testInterval}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() { w.Run(ctx); close(done) }()
|
||||
|
||||
time.Sleep(testInterval * 2)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return within 2s of cancel")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -0,0 +1,69 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type LogDataEvent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (e LogDataEvent) Type() uint32 {
|
||||
return LogDataEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
+246
-12
@@ -1,19 +1,30 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -26,6 +37,17 @@ func TestMain(m *testing.M) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
@@ -33,26 +55,238 @@ func TestMain(m *testing.M) {
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||
t.Helper()
|
||||
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||
require.NoError(t, err)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Create a process configuration
|
||||
return ModelConfig{
|
||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||
// ID itself is used.
|
||||
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||
for _, pg := range pm.processGroups {
|
||||
for modelID, process := range pg.processes {
|
||||
respond := modelID
|
||||
if r, ok := modelResponses[modelID]; ok {
|
||||
respond = r
|
||||
}
|
||||
process.testHandler = newTestHandler(respond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||
// It supports the endpoints that routing tests depend on, without launching
|
||||
// any subprocess or binding any port.
|
||||
func newTestHandler(respond string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||
|
||||
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||
if d, err := time.ParseDuration(wait); err == nil {
|
||||
time.Sleep(d)
|
||||
}
|
||||
}
|
||||
|
||||
if isStreaming {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
flusher := w.(http.Flusher)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
finalData, _ := json.Marshal(map[string]any{
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
"timings": map[string]any{
|
||||
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||
"predicted_ms": 17, "predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != respond {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"responseMessage": respond,
|
||||
"usage": map[string]any{
|
||||
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
model := r.FormValue("model")
|
||||
if model == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
fileBytes, _ := io.ReadAll(file)
|
||||
file.Close()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||
"model": model,
|
||||
"h_content_type": r.Header.Get("Content-Type"),
|
||||
"h_content_length": r.Header.Get("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||
model := r.URL.Query().Get("model")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"voices": []string{"voice1"}, "model": model,
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprint(w, respond)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"model": modelName, "images": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"loras": []string{},
|
||||
})
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 1em;
|
||||
padding: 10px;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...
|
||||
</pre>
|
||||
|
||||
<script>
|
||||
// Establish an EventSource connection to the SSE endpoint
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
// Append the new log message to the <pre> element
|
||||
const logStream = document.getElementById('log-stream');
|
||||
|
||||
logStream.textContent += event.data;
|
||||
|
||||
// Auto-scroll to the bottom
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
console.error("EventSource failed:", err);
|
||||
};
|
||||
} else {
|
||||
console.error("SSE not supported by this browser.");
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
+217
-44
@@ -1,20 +1,121 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
|
||||
func newCircularBuffer(capacity int) *circularBuffer {
|
||||
return &circularBuffer{
|
||||
data: make([]byte, capacity),
|
||||
head: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||
// Data is copied into the internal buffer (not stored by reference).
|
||||
func (cb *circularBuffer) Write(p []byte) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cap := len(cb.data)
|
||||
|
||||
// If input is larger than capacity, only keep the last cap bytes
|
||||
if len(p) >= cap {
|
||||
copy(cb.data, p[len(p)-cap:])
|
||||
cb.head = 0
|
||||
cb.size = cap
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate how much space is available from head to end of buffer
|
||||
firstPart := cap - cb.head
|
||||
if firstPart >= len(p) {
|
||||
// All data fits without wrapping
|
||||
copy(cb.data[cb.head:], p)
|
||||
cb.head = (cb.head + len(p)) % cap
|
||||
} else {
|
||||
// Data wraps around
|
||||
copy(cb.data[cb.head:], p[:firstPart])
|
||||
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||
cb.head = len(p) - firstPart
|
||||
}
|
||||
|
||||
// Update size
|
||||
cb.size += len(p)
|
||||
if cb.size > cap {
|
||||
cb.size = cap
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||
// Returns a new slice (copy), not a view into internal buffer.
|
||||
func (cb *circularBuffer) GetHistory() []byte {
|
||||
if cb.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, cb.size)
|
||||
cap := len(cb.data)
|
||||
|
||||
// Calculate start position (oldest data)
|
||||
start := (cb.head - cb.size + cap) % cap
|
||||
|
||||
if start+cb.size <= cap {
|
||||
// Data is contiguous, single copy
|
||||
copy(result, cb.data[start:start+cb.size])
|
||||
} else {
|
||||
// Data wraps around, two copies
|
||||
firstPart := cap - start
|
||||
copy(result[:firstPart], cb.data[start:])
|
||||
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LevelDebug LogLevel = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
|
||||
LogBufferSize = 100 * 1024
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
buffer *circularBuffer
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
stdout io.Writer
|
||||
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
|
||||
// timestamps
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -23,9 +124,12 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: nil, // lazy initialized on first Write
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
timeFormat: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,57 +144,126 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
if w.buffer == nil {
|
||||
w.buffer = newCircularBuffer(LogBufferSize)
|
||||
}
|
||||
w.buffer.Write(p)
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
w.broadcast(p)
|
||||
// Make a copy for broadcast to preserve immutability
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return w.buffer.GetHistory()
|
||||
}
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
// Clear releases the buffer memory, making it eligible for GC.
|
||||
// The buffer will be lazily re-allocated on the next Write.
|
||||
func (w *LogMonitor) Clear() {
|
||||
w.bufferMu.Lock()
|
||||
w.buffer = nil
|
||||
w.bufferMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||
callback(e.Data)
|
||||
})
|
||||
return history
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Subscribe() chan []byte {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
ch := make(chan []byte, 100)
|
||||
w.clients[ch] = true
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
delete(w.clients, ch)
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) broadcast(msg []byte) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||
}
|
||||
|
||||
for client := range w.clients {
|
||||
select {
|
||||
case client <- msg:
|
||||
default:
|
||||
// If client buffer is full, skip
|
||||
}
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.prefix = prefix
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.timeFormat = timeFormat
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
timestamp := ""
|
||||
if w.timeFormat != "" {
|
||||
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
if level < w.level {
|
||||
return
|
||||
}
|
||||
w.Write(w.formatMessage(level.String(), msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debug(msg string) {
|
||||
w.log(LevelDebug, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Info(msg string) {
|
||||
w.log(LevelInfo, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warn(msg string) {
|
||||
w.log(LevelWarn, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Error(msg string) {
|
||||
w.log(LevelError, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case LevelDebug:
|
||||
return "DEBUG"
|
||||
case LevelInfo:
|
||||
return "INFO"
|
||||
case LevelWarn:
|
||||
return "WARN"
|
||||
case LevelError:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
+243
-22
@@ -3,45 +3,38 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Test subscription
|
||||
client1 := logMonitor.Subscribe()
|
||||
client2 := logMonitor.Subscribe()
|
||||
|
||||
defer logMonitor.Unsubscribe(client1)
|
||||
defer logMonitor.Unsubscribe(client2)
|
||||
// A WaitGroup is used to wait for all the expected writes to complete
|
||||
var wg sync.WaitGroup
|
||||
|
||||
client1Messages := make([]byte, 0)
|
||||
client2Messages := make([]byte, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client1Messages = append(client1Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case data := <-client1:
|
||||
client1Messages = append(client1Messages, data...)
|
||||
case data := <-client2:
|
||||
client2Messages = append(client2Messages, data...)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client2Messages = append(client2Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
wg.Add(6) // 2 x 3 writes
|
||||
|
||||
logMonitor.Write([]byte("1"))
|
||||
logMonitor.Write([]byte("2"))
|
||||
logMonitor.Write([]byte("3"))
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
// wait for all writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check the buffer
|
||||
@@ -93,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Enable timestamps
|
||||
lm.timeFormat = time.RFC3339
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
lm.Info("Hello, World!")
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
timestamp := ""
|
||||
fields := strings.Fields(string(history))
|
||||
if len(fields) > 0 {
|
||||
timestamp = fields[0]
|
||||
} else {
|
||||
t.Fatalf("Cannot extract string from history")
|
||||
}
|
||||
|
||||
_, err := time.Parse(time.RFC3339, timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||
// Create a small buffer to test wrap-around
|
||||
cb := newCircularBuffer(10)
|
||||
|
||||
// Write "hello" (5 bytes)
|
||||
cb.Write([]byte("hello"))
|
||||
if got := string(cb.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Write "world" (5 bytes) - buffer now full
|
||||
cb.Write([]byte("world"))
|
||||
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||
t.Errorf("Expected 'helloworld', got %q", got)
|
||||
}
|
||||
|
||||
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||
cb.Write([]byte("12345"))
|
||||
if got := string(cb.GetHistory()); got != "world12345" {
|
||||
t.Errorf("Expected 'world12345', got %q", got)
|
||||
}
|
||||
|
||||
// Write data larger than buffer capacity
|
||||
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||
// Test empty buffer
|
||||
cb := newCircularBuffer(10)
|
||||
if got := cb.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
|
||||
// Test exact capacity
|
||||
cb.Write([]byte("1234567890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
|
||||
// Test write exactly at capacity boundary
|
||||
cb = newCircularBuffer(10)
|
||||
cb.Write([]byte("12345"))
|
||||
cb.Write([]byte("67890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Buffer should be nil before any writes
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil before first write")
|
||||
}
|
||||
|
||||
// GetHistory should return nil when buffer is nil
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history before first write, got %q", got)
|
||||
}
|
||||
|
||||
// Write should lazily initialize the buffer
|
||||
lm.Write([]byte("test"))
|
||||
|
||||
if lm.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized after write")
|
||||
}
|
||||
|
||||
if got := string(lm.GetHistory()); got != "test" {
|
||||
t.Errorf("Expected 'test', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_Clear(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write some data
|
||||
lm.Write([]byte("hello"))
|
||||
if got := string(lm.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Clear should release the buffer
|
||||
lm.Clear()
|
||||
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil after Clear")
|
||||
}
|
||||
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write, clear, then write again
|
||||
lm.Write([]byte("first"))
|
||||
lm.Clear()
|
||||
lm.Write([]byte("second"))
|
||||
|
||||
if got := string(lm.GetHistory()); got != "second" {
|
||||
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||
// Test data of varying sizes
|
||||
smallMsg := []byte("small message\n")
|
||||
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||
|
||||
b.Run("SmallWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(smallMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MediumWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LargeWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(largeMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithSubscribers", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Add some subscribers
|
||||
for i := 0; i < 5; i++ {
|
||||
lm.OnLogData(func(data []byte) {})
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetHistory", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Pre-populate with data
|
||||
for i := 0; i < 1000; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.GetHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
Benchmark Results - MBP M1 Pro
|
||||
|
||||
Before (ring.Ring):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||
|
||||
After (circularBuffer 10KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||
|
||||
After (circularBuffer 100KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|-----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||
|
||||
Summary:
|
||||
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||
- Allocations: reduced from 2 to 1 across all operations
|
||||
- Small/medium writes: ~1.1-1.6x faster
|
||||
*/
|
||||
|
||||
+329
@@ -0,0 +1,329 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
// MatrixSolver contains pure swap-decision logic with no Process dependencies.
|
||||
// It is safe for concurrent reads after construction.
|
||||
type MatrixSolver struct {
|
||||
expandedSets []config.ExpandedSet // all valid model combinations
|
||||
evictCosts map[string]int // real model name -> eviction cost (default 1)
|
||||
modelToSets map[string][]int // model name -> indices into expandedSets
|
||||
}
|
||||
|
||||
// NewMatrixSolver builds a solver from expanded sets and eviction costs.
|
||||
func NewMatrixSolver(expandedSets []config.ExpandedSet, evictCosts map[string]int) *MatrixSolver {
|
||||
modelToSets := make(map[string][]int)
|
||||
for i, es := range expandedSets {
|
||||
for _, model := range es.Models {
|
||||
modelToSets[model] = append(modelToSets[model], i)
|
||||
}
|
||||
}
|
||||
|
||||
return &MatrixSolver{
|
||||
expandedSets: expandedSets,
|
||||
evictCosts: evictCosts,
|
||||
modelToSets: modelToSets,
|
||||
}
|
||||
}
|
||||
|
||||
// SolveResult describes what the solver decided.
|
||||
type SolveResult struct {
|
||||
Evict []string // running models that must be stopped
|
||||
TargetSet []string // the chosen set of models (for informational purposes)
|
||||
SetName string // name of the chosen set
|
||||
DSL string // original DSL expression for the chosen set
|
||||
TotalCost int // total eviction cost
|
||||
}
|
||||
|
||||
// Solve determines which models to evict when a model is requested.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. If requestedModel is already running, no eviction needed.
|
||||
// 2. Find all sets containing requestedModel.
|
||||
// 3. If no sets found, the model runs alone; evict all running models.
|
||||
// 4. For each candidate set, compute cost = sum of evict_costs for running
|
||||
// models NOT in that set.
|
||||
// 5. Pick lowest cost. Ties broken by definition order (index in expandedSets).
|
||||
// 6. Return models to evict and the chosen set.
|
||||
func (s *MatrixSolver) Solve(requestedModel string, runningModels []string) (SolveResult, error) {
|
||||
// If already running, nothing to do (but fill in set info for logging)
|
||||
if slices.Contains(runningModels, requestedModel) {
|
||||
setName, dsl := s.findMatchingSet(requestedModel, runningModels)
|
||||
return SolveResult{
|
||||
TargetSet: runningModels,
|
||||
SetName: setName,
|
||||
DSL: dsl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
candidateIndices := s.modelToSets[requestedModel]
|
||||
|
||||
// Model not in any set: runs alone, evict everything
|
||||
if len(candidateIndices) == 0 {
|
||||
evict := make([]string, len(runningModels))
|
||||
copy(evict, runningModels)
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: []string{requestedModel},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find the cheapest candidate set
|
||||
bestCost := -1
|
||||
bestIdx := -1
|
||||
|
||||
for _, idx := range candidateIndices {
|
||||
setModels := s.expandedSets[idx].Models
|
||||
cost := 0
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(setModels, running) {
|
||||
cost += s.evictCost(running)
|
||||
}
|
||||
}
|
||||
|
||||
if bestCost < 0 || cost < bestCost || (cost == bestCost && idx < bestIdx) {
|
||||
bestCost = cost
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which running models to evict
|
||||
chosen := s.expandedSets[bestIdx]
|
||||
var evict []string
|
||||
for _, running := range runningModels {
|
||||
if !slices.Contains(chosen.Models, running) {
|
||||
evict = append(evict, running)
|
||||
}
|
||||
}
|
||||
|
||||
return SolveResult{
|
||||
Evict: evict,
|
||||
TargetSet: chosen.Models,
|
||||
SetName: chosen.SetName,
|
||||
DSL: chosen.DSL,
|
||||
TotalCost: bestCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findMatchingSet finds the expanded set that contains all running models.
|
||||
// Returns the set name and DSL, or empty strings if no match.
|
||||
func (s *MatrixSolver) findMatchingSet(requestedModel string, runningModels []string) (string, string) {
|
||||
for _, idx := range s.modelToSets[requestedModel] {
|
||||
set := s.expandedSets[idx]
|
||||
allInSet := true
|
||||
for _, m := range runningModels {
|
||||
if !slices.Contains(set.Models, m) {
|
||||
allInSet = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSet {
|
||||
return set.SetName, set.DSL
|
||||
}
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (s *MatrixSolver) evictCost(model string) int {
|
||||
if cost, ok := s.evictCosts[model]; ok {
|
||||
return cost
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Matrix manages processes using solver-based swap logic.
|
||||
type Matrix struct {
|
||||
sync.Mutex
|
||||
solver *MatrixSolver
|
||||
processes map[string]*Process // all processes keyed by real model name
|
||||
config config.Config
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
|
||||
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||
// request that needs to evict models waits for inflight to drain under
|
||||
// m.Lock before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||
// races with Stop()'s Wait() and can be killed mid-request.
|
||||
inflight sync.WaitGroup
|
||||
|
||||
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||
// after m.Lock is released but before the request is dispatched to
|
||||
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||
// race window to deterministically reproduce the race.
|
||||
testDelayFastPath func()
|
||||
}
|
||||
|
||||
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||
// model defined in the config (any model can run alone even if not in a set).
|
||||
func NewMatrix(cfg config.Config, proxyLogger, upstreamLogger *LogMonitor) *Matrix {
|
||||
processes := make(map[string]*Process)
|
||||
for modelID, modelConfig := range cfg.Models {
|
||||
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, cfg.HealthCheckTimeout, modelConfig, processLogger, proxyLogger)
|
||||
processes[modelID] = process
|
||||
}
|
||||
|
||||
evictCosts := cfg.Matrix.ResolvedEvictCosts()
|
||||
|
||||
return &Matrix{
|
||||
solver: NewMatrixSolver(cfg.ExpandedSets, evictCosts),
|
||||
processes: processes,
|
||||
config: cfg,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest handles the swap logic and proxies the request to the model.
|
||||
func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %s not found in matrix", modelID)
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
running := m.runningModels()
|
||||
result, err := m.solver.Solve(modelID, running)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return fmt.Errorf("matrix solver error: %w", err)
|
||||
}
|
||||
|
||||
// Log solver decision
|
||||
if len(result.Evict) > 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s set=%s dsl=%q evict=%v target=%v cost=%d",
|
||||
modelID, result.SetName, result.DSL, result.Evict, result.TargetSet, result.TotalCost)
|
||||
} else if len(running) == 0 {
|
||||
m.proxyLogger.Infof("Matrix: model=%s starting (no models running)", modelID)
|
||||
} else {
|
||||
m.proxyLogger.Debugf("Matrix: model=%s already running in set=%s dsl=%q", modelID, result.SetName, result.DSL)
|
||||
}
|
||||
|
||||
// Evict models that need to be stopped
|
||||
if len(result.Evict) > 0 {
|
||||
// Wait for any in-flight ProxyRequest calls to register on their
|
||||
// Process before stopping anything. Without this, a request that
|
||||
// released m.Lock but has not yet incremented
|
||||
// Process.inFlightRequests races with Stop() and can be killed
|
||||
// mid-request.
|
||||
m.inflight.Wait()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, evictModel := range result.Evict {
|
||||
if p, exists := m.processes[evictModel]; exists {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Stop()
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Register this request in inflight before releasing m.Lock so a
|
||||
// concurrent eviction will wait for it to complete.
|
||||
m.inflight.Add(1)
|
||||
defer m.inflight.Done()
|
||||
isFastPath := len(result.Evict) == 0
|
||||
m.Unlock()
|
||||
|
||||
if isFastPath && m.testDelayFastPath != nil {
|
||||
m.testDelayFastPath()
|
||||
}
|
||||
|
||||
// Proxy the request (Process handles on-demand start)
|
||||
process.ProxyRequest(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProcesses stops all running processes.
|
||||
func (m *Matrix) StopProcesses(strategy StopStrategy) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
p.StopImmediately()
|
||||
default:
|
||||
p.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// StopProcess stops a single process by model ID.
|
||||
func (m *Matrix) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
process, ok := m.processes[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down all processes.
|
||||
func (m *Matrix) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range m.processes {
|
||||
wg.Add(1)
|
||||
go func(p *Process) {
|
||||
defer wg.Done()
|
||||
p.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||
func (m *Matrix) RunningModels() []string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.runningModels()
|
||||
}
|
||||
|
||||
// runningModels returns running model names (caller must hold lock).
|
||||
func (m *Matrix) runningModels() []string {
|
||||
var running []string
|
||||
for id, process := range m.processes {
|
||||
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||
running = append(running, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(running)
|
||||
return running
|
||||
}
|
||||
|
||||
// GetProcess returns the Process for a model.
|
||||
func (m *Matrix) GetProcess(modelID string) (*Process, bool) {
|
||||
p, ok := m.processes[modelID]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// HasModel returns true if the model is managed by this matrix.
|
||||
func (m *Matrix) HasModel(modelID string) bool {
|
||||
_, ok := m.processes[modelID]
|
||||
return ok
|
||||
}
|
||||
@@ -0,0 +1,349 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper to build expanded sets for solver tests
|
||||
func makeExpandedSets(sets ...struct {
|
||||
name string
|
||||
models []string
|
||||
}) []config.ExpandedSet {
|
||||
var result []config.ExpandedSet
|
||||
for _, s := range sets {
|
||||
result = append(result, config.ExpandedSet{
|
||||
SetName: s.name,
|
||||
Models: s.models,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func es(name string, models ...string) struct {
|
||||
name string
|
||||
models []string
|
||||
} {
|
||||
return struct {
|
||||
name string
|
||||
models []string
|
||||
}{name, models}
|
||||
}
|
||||
|
||||
func TestMatrixSolver_AlreadyRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"a"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a"}, result.TargetSet)
|
||||
assert.Equal(t, "s1", result.SetName)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_RunsAlone(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Model "c" not in any set
|
||||
result, err := solver.Solve("c", []string{"a", "b"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"a", "b"}, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NotInAnySet_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("c", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"c"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_SingleSet_EvictsNonMembers(t *testing.T) {
|
||||
// Set: [a, b]. Request a when b and c are running.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(es("s1", "a", "b")),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("a", []string{"b", "c"})
|
||||
require.NoError(t, err)
|
||||
// c is not in the set, so it gets evicted. b is in the set, so it stays.
|
||||
assert.Equal(t, []string{"c"}, result.Evict)
|
||||
assert.Equal(t, []string{"a", "b"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_PicksLowestCost(t *testing.T) {
|
||||
// Two sets containing model "a":
|
||||
// s1: [a, v] — if v is running, cost=0; if L is running, cost=30
|
||||
// s2: [a, L] — if L is running, cost=0; if v is running, cost=50
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "v"),
|
||||
es("s2", "a", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30},
|
||||
)
|
||||
|
||||
// v is running. Switching to a:
|
||||
// s1 cost: v is in s1, so 0
|
||||
// s2 cost: v is NOT in s2, so 50
|
||||
// => pick s1
|
||||
result, err := solver.Solve("a", []string{"v"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "v"}, result.TargetSet)
|
||||
|
||||
// L is running. Switching to a:
|
||||
// s1 cost: L is NOT in s1, so 30
|
||||
// s2 cost: L is in s2, so 0
|
||||
// => pick s2
|
||||
result, err = solver.Solve("a", []string{"L"})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "L"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_TieBreakingByDefinitionOrder(t *testing.T) {
|
||||
// Two sets with identical cost. Definition order should win.
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "a", "x"),
|
||||
es("s2", "a", "y"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Nothing running, both sets cost 0. s1 is first.
|
||||
result, err := solver.Solve("a", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"a", "x"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_EvictCostPreservesExpensive(t *testing.T) {
|
||||
// Model "v" costs 50 to evict, "m" costs 1 (default).
|
||||
// Sets: [g,v], [g,m]
|
||||
// Running: v, m. Request g.
|
||||
// s1=[g,v]: evict m (cost 1), keep v
|
||||
// s2=[g,m]: evict v (cost 50), keep m
|
||||
// => pick s1
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "g", "m"),
|
||||
),
|
||||
map[string]int{"v": 50},
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{"v", "m"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"m"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
func TestMatrixSolver_NothingRunning(t *testing.T) {
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("s1", "g", "v"),
|
||||
es("s2", "q", "v"),
|
||||
),
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := solver.Solve("g", []string{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Evict)
|
||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||
}
|
||||
|
||||
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||
// pending request and kills it mid-start.
|
||||
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
ExpandedSets: []config.ExpandedSet{
|
||||
{SetName: "s1", Models: []string{"model1"}},
|
||||
{SetName: "s2", Models: []string{"model2"}},
|
||||
},
|
||||
Matrix: &config.MatrixConfig{},
|
||||
}
|
||||
|
||||
m := NewMatrix(cfg, testLogger, testLogger)
|
||||
defer m.StopProcesses(StopImmediately)
|
||||
|
||||
// Bypass real subprocesses so the test is fast and deterministic.
|
||||
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||
|
||||
// Prime: run a request through model1 so it reaches StateReady and
|
||||
// subsequent requests take the no-eviction path.
|
||||
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
primeW := httptest.NewRecorder()
|
||||
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||
require.Equal(t, http.StatusOK, primeW.Code)
|
||||
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||
|
||||
// Install fast-path hook that signals arrival and waits for release.
|
||||
// This parks R2 at the race window — after m.Lock is released but
|
||||
// before Process.inFlightRequests.Add(1).
|
||||
r2Reached := make(chan struct{})
|
||||
r2Release := make(chan struct{})
|
||||
m.testDelayFastPath = func() {
|
||||
close(r2Reached)
|
||||
<-r2Release
|
||||
}
|
||||
|
||||
// R2: no-eviction request for model1. Will pause at the hook.
|
||||
r2Done := make(chan struct{})
|
||||
w2 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r2Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||
}()
|
||||
|
||||
// Deterministically wait for R2 to reach the race window.
|
||||
<-r2Reached
|
||||
|
||||
// R3: request for model2 which requires evicting model1. Must wait for
|
||||
// R2 to finish before touching model1.
|
||||
r3Done := make(chan struct{})
|
||||
w3 := httptest.NewRecorder()
|
||||
go func() {
|
||||
defer close(r3Done)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||
}()
|
||||
|
||||
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||
// holding the lock, so TryLock keeps failing.
|
||||
for m.TryLock() {
|
||||
m.Unlock()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||
deadline := time.Now().Add(100 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if m.processes["model1"].CurrentState() != StateReady ||
|
||||
m.processes["model2"].CurrentState() != StateStopped {
|
||||
break
|
||||
}
|
||||
done := false
|
||||
select {
|
||||
case <-r3Done:
|
||||
done = true
|
||||
default:
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||
select {
|
||||
case <-r3Done:
|
||||
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||
default:
|
||||
}
|
||||
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||
"model1 must stay Ready while an in-flight request is pending")
|
||||
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||
|
||||
// Release R2 and let both requests finish.
|
||||
close(r2Release)
|
||||
<-r2Done
|
||||
<-r3Done
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Contains(t, w2.Body.String(), "model1")
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Contains(t, w3.Body.String(), "model2")
|
||||
}
|
||||
|
||||
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||
// Simulates the example config:
|
||||
// standard: [g,v], [q,v], [m,v]
|
||||
// with_rerank: [g,v,e], [q,v,e]
|
||||
// creative: [g,sd], [q,sd]
|
||||
// full: [L]
|
||||
solver := NewMatrixSolver(
|
||||
makeExpandedSets(
|
||||
es("standard", "g", "v"),
|
||||
es("standard", "q", "v"),
|
||||
es("standard", "m", "v"),
|
||||
es("with_rerank", "e", "g", "v"),
|
||||
es("with_rerank", "e", "q", "v"),
|
||||
es("creative", "g", "sd"),
|
||||
es("creative", "q", "sd"),
|
||||
es("full", "L"),
|
||||
),
|
||||
map[string]int{"v": 50, "L": 30, "whisper": 10},
|
||||
)
|
||||
|
||||
// Running: g, v. Request q.
|
||||
// standard[q,v]: evict g (cost 1), keep v. Total: 1.
|
||||
// with_rerank[q,v,e]: evict g (cost 1), keep v. Total: 1.
|
||||
// => tie, pick first by definition order = standard[q,v]
|
||||
result, err := solver.Solve("q", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"g"}, result.Evict)
|
||||
assert.Equal(t, []string{"q", "v"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request L.
|
||||
// full[L]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// Only one set contains L, so pick it.
|
||||
result, err = solver.Solve("L", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"g", "v"}, result.Evict)
|
||||
assert.Equal(t, []string{"L"}, result.TargetSet)
|
||||
|
||||
// Running: g, v. Request sd.
|
||||
// creative[g,sd]: evict v (cost 50). Total: 50.
|
||||
// creative[q,sd]: evict g (cost 1) + v (cost 50). Total: 51.
|
||||
// => pick creative[g,sd]
|
||||
result, err = solver.Solve("sd", []string{"g", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"v"}, result.Evict)
|
||||
assert.Equal(t, []string{"g", "sd"}, result.TargetSet)
|
||||
|
||||
// Running: q, v, e. Request g.
|
||||
// standard[g,v]: evict q (1) + e (1). Total: 2.
|
||||
// with_rerank[g,v,e]: evict q (1). Total: 1.
|
||||
// creative[g,sd]: evict q (1) + v (50) + e (1). Total: 52.
|
||||
// => pick with_rerank[g,v,e]
|
||||
result, err = solver.Solve("g", []string{"e", "q", "v"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"q"}, result.Evict)
|
||||
assert.Equal(t, []string{"e", "g", "v"}, result.TargetSet)
|
||||
}
|
||||
@@ -0,0 +1,592 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||
var zstdEncOptions = []zstd.EOption{
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
}
|
||||
|
||||
// zstdDecOptions are the shared zstd decoder options.
|
||||
var zstdDecOptions = []zstd.DOption{}
|
||||
|
||||
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||
var zstdEncPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||
return enc
|
||||
},
|
||||
}
|
||||
|
||||
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||
var zstdDecPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
|
||||
// Returns compressed bytes and the original JSON byte count for logging.
|
||||
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||
jsonBytes, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||
}
|
||||
enc := zstdEncPool.Get().(*zstd.Encoder)
|
||||
defer zstdEncPool.Put(enc)
|
||||
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
|
||||
}
|
||||
|
||||
// decompressCapture decompresses zstd-compressed JSON and returns it.
|
||||
func decompressCapture(data []byte) ([]byte, error) {
|
||||
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||
defer zstdDecPool.Put(dec)
|
||||
return dec.DecodeAll(data, nil)
|
||||
}
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total compressed size in bytes
|
||||
maxCaptureSize int // max bytes for captures (uncompressed)
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int][]byte),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics = append(mp.metrics, metric)
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
}
|
||||
|
||||
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||
return
|
||||
}
|
||||
|
||||
captureSize := len(compressed)
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
}
|
||||
|
||||
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Evict oldest (FIFO) until room available for the compressed data
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
l := len(evicted)
|
||||
mp.captureSize -= l
|
||||
delete(mp.captures, oldestID)
|
||||
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = compressed
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
|
||||
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||
}
|
||||
|
||||
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
data, exists := mp.captures[id]
|
||||
return data, exists
|
||||
}
|
||||
|
||||
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
|
||||
// If decompress=false, returns the raw zstd-compressed bytes.
|
||||
// Returns nil if the capture is not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
data, exists := mp.captures[id]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !decompress {
|
||||
return data
|
||||
}
|
||||
|
||||
decompressed, err := decompressCapture(data)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return decompressed
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns metrics as JSON
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
compressed, _, err := compressCapture(capture)
|
||||
if err == nil && len(compressed) <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||
// Iterate **backwards** through the body looking for the data payload with
|
||||
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||
|
||||
// Start from the end of the body and scan backwards for newlines
|
||||
pos := len(body)
|
||||
for pos > 0 {
|
||||
// Find the previous newline (or start of body)
|
||||
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||
if lineStart == -1 {
|
||||
lineStart = 0
|
||||
} else {
|
||||
lineStart++ // Move past the newline
|
||||
}
|
||||
|
||||
line := bytes.TrimSpace(body[lineStart:pos])
|
||||
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
parsed := gjson.ParseBytes(data)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// v1/responses format nests usage under response.usage
|
||||
if !usage.Exists() {
|
||||
usage = parsed.Get("response.usage")
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
return parseMetrics(modelID, start, usage, timings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
wallDurationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
// default values
|
||||
cachedTokens := -1 // unknown or missing data
|
||||
outputTokens := 0
|
||||
inputTokens := 0
|
||||
|
||||
// timings data
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := wallDurationMs
|
||||
|
||||
if usage.Exists() {
|
||||
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||
// v1/chat/completions
|
||||
inputTokens = int(pt.Int())
|
||||
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||
// v1/messages
|
||||
inputTokens = int(it.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||
// v1/chat/completions
|
||||
outputTokens = int(ct.Int())
|
||||
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||
outputTokens = int(ot.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||
cachedTokens = int(ct.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings.Exists() {
|
||||
inputTokens = int(timings.Get("prompt_n").Int())
|
||||
outputTokens = int(timings.Get("predicted_n").Int())
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
timingsDurationMs := int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
if timingsDurationMs > durationMs {
|
||||
durationMs = timingsDurationMs
|
||||
}
|
||||
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = int(cachedValue.Int())
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
CachedTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if w.start.IsZero() {
|
||||
w.start = time.Now()
|
||||
}
|
||||
|
||||
// Single write operation that writes to both the response and buffer
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for part := range strings.SplitSeq(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding, _, _ := strings.Cut(strings.TrimSpace(part), ";")
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,143 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
|
||||
// Create a transport with per-peer timeout configuration
|
||||
peerTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
||||
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
||||
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
||||
}
|
||||
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_CustomTimeouts(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://localhost:8080")
|
||||
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"test-peer": config.PeerConfig{
|
||||
Proxy: "http://localhost:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"model1"},
|
||||
Timeouts: config.TimeoutsConfig{
|
||||
Connect: 45,
|
||||
ResponseHeader: 300,
|
||||
TLSHandshake: 15,
|
||||
ExpectContinue: 2,
|
||||
IdleConn: 120,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(peers, testLogger)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, peerProxy)
|
||||
assert.True(t, peerProxy.HasPeerModel("model1"))
|
||||
|
||||
// Verify the timeout values are actually applied to the transport
|
||||
member, found := peerProxy.proxyMap["model1"]
|
||||
require.True(t, found, "model1 should exist in proxyMap")
|
||||
assert.NotNil(t, member.reverseProxy)
|
||||
assert.NotNil(t, member.reverseProxy.Transport)
|
||||
|
||||
transport, ok := member.reverseProxy.Transport.(*http.Transport)
|
||||
require.True(t, ok, "Transport should be *http.Transport")
|
||||
|
||||
// Verify all timeout values are correctly applied
|
||||
assert.Equal(t, 300*time.Second, transport.ResponseHeaderTimeout)
|
||||
assert.Equal(t, 15*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Equal(t, 2*time.Second, transport.ExpectContinueTimeout)
|
||||
assert.Equal(t, 120*time.Second, transport.IdleConnTimeout)
|
||||
// ForceAttemptHTTP2 should be enabled
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
+840
-217
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user