Compare commits
189 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b4ee70154 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e | |||
| 86e9b93c37 | |||
| 3acace810f | |||
| 554d29e87d | |||
| 3567b7df08 | |||
| 38738525c9 | |||
| c0fc858193 | |||
| b429349e8a | |||
| eab2efd7b5 | |||
| 6aedbe121a | |||
| b24467ab89 | |||
| 12b69fb718 | |||
| f91a8b2462 | |||
| a89b803d4a | |||
| f852689104 | |||
| e250e71e59 | |||
| d18dc26d01 | |||
| 8357714421 | |||
| c07179d6e2 | |||
| 7ff50631e0 | |||
| 9fc0431531 | |||
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 | |||
| fc3bb716df | |||
| c36986fef6 | |||
| 558801db1a | |||
| b21dee27c1 | |||
| f58c8c8ec5 | |||
| 954e2dee73 | |||
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 |
@@ -8,8 +8,15 @@ reviews:
|
|||||||
poem: false
|
poem: false
|
||||||
review_status: true
|
review_status: true
|
||||||
collapse_walkthrough: false
|
collapse_walkthrough: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
finishing_touches:
|
||||||
|
docstrings:
|
||||||
|
enabled: false
|
||||||
auto_review:
|
auto_review:
|
||||||
enabled: true
|
enabled: true
|
||||||
drafts: false
|
drafts: false
|
||||||
chat:
|
chat:
|
||||||
auto_reply: true
|
auto_reply: true
|
||||||
|
issue_enrichment:
|
||||||
|
planning:
|
||||||
|
enabled: false
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
---
|
||||||
|
name: Bug Report
|
||||||
|
about: I found a defect
|
||||||
|
title: ''
|
||||||
|
labels: 'unconfirmed bug'
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||||
|
|
||||||
|
**Describe the bug**
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
**Expected behaviour**
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Operating system and version**
|
||||||
|
|
||||||
|
- OS: (linux, osx, windows, freebsd, etc)
|
||||||
|
- GPUs: (list architecture)
|
||||||
|
|
||||||
|
**My Configuration**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# copy / paste your configuration here
|
||||||
|
```
|
||||||
|
|
||||||
|
**Proxy Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy / paste from /logs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Upstream Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy/paste from /logs
|
||||||
|
```
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
name: Validate JSON Schema
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "config-schema.json"
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
paths:
|
||||||
|
- "config-schema.json"
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
validate-schema:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate JSON Schema
|
||||||
|
run: |
|
||||||
|
# Check if the file is valid JSON
|
||||||
|
if ! jq empty config-schema.json 2>/dev/null; then
|
||||||
|
echo "Error: config-schema.json is not valid JSON"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Validate that it's a valid JSON Schema
|
||||||
|
# Check for required $schema field
|
||||||
|
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
|
||||||
|
echo "Warning: config-schema.json should have a \$schema field"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check that it has either properties or definitions
|
||||||
|
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
|
||||||
|
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "✓ config-schema.json is valid"
|
||||||
@@ -10,17 +10,37 @@ on:
|
|||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
# Run on workflow file changes (without pushing)
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/containers.yml'
|
||||||
|
- 'docker/build-container.sh'
|
||||||
|
- 'docker/*.Containerfile'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-and-push:
|
build-and-push:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [intel, cuda, vulkan, cpu, musa]
|
platform: [intel, cuda, vulkan, cpu, musa, rocm]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Free up disk space
|
||||||
|
if: matrix.platform == 'rocm'
|
||||||
|
run: |
|
||||||
|
echo "Before cleanup:"
|
||||||
|
df -h
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /opt/ghc
|
||||||
|
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||||
|
sudo docker system prune -af
|
||||||
|
echo "After cleanup:"
|
||||||
|
df -h
|
||||||
|
|
||||||
- name: Log in to GitHub Container Registry
|
- name: Log in to GitHub Container Registry
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
@@ -31,7 +51,7 @@ jobs:
|
|||||||
- name: Run build-container
|
- name: Run build-container
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||||
|
|
||||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
name: Windows CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
# only run when backend source changes
|
||||||
|
# cmd/ is excluded because it contains utilities without tests
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci-windows.yml'
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci-windows.yml'
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
# necessary for testing proxy/Process swapping
|
||||||
|
- name: Create simple-responder
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: make simple-responder-windows
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
- name: Test all
|
||||||
|
shell: bash
|
||||||
|
run: make test-all
|
||||||
@@ -1,13 +1,27 @@
|
|||||||
# This workflow will build a golang project
|
name: Linux CI
|
||||||
|
|
||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "main" ]
|
branches: [ "main" ]
|
||||||
|
# only run when backend source changes
|
||||||
|
# cmd/ is excluded because it contains utilities without tests
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci.yml'
|
||||||
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "main" ]
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- '**/*.go'
|
||||||
|
- '!cmd/**'
|
||||||
|
- 'go.mod'
|
||||||
|
- 'go.sum'
|
||||||
|
- 'Makefile'
|
||||||
|
- '.github/workflows/go-ci.yml'
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
@@ -24,9 +38,33 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
go-version: '1.23'
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# Only run in this linux based runner
|
||||||
|
- name: Check Formatting
|
||||||
|
run: |
|
||||||
|
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||||
|
gofmt -l . | grep -v 'event/.*_test.go'
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
run: make simple-responder
|
run: make simple-responder
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
run: make test-all
|
run: make test-all
|
||||||
@@ -3,10 +3,14 @@ name: goreleaser
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- "*"
|
||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: "Tag version to release (e.g. v144)"
|
||||||
|
required: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
@@ -15,22 +19,56 @@ jobs:
|
|||||||
goreleaser:
|
goreleaser:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
-
|
- name: Checkout
|
||||||
name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
-
|
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||||
name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
-
|
- name: Set up Node.js
|
||||||
name: Run GoReleaser
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "24"
|
||||||
|
- name: Install dependencies and build UI
|
||||||
|
run: |
|
||||||
|
cd ui-svelte
|
||||||
|
npm ci
|
||||||
|
npm run build
|
||||||
|
|
||||||
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v6
|
uses: goreleaser/goreleaser-action@v6
|
||||||
with:
|
with:
|
||||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||||
distribution: goreleaser
|
distribution: goreleaser
|
||||||
# 'latest', 'nightly', or a semver
|
# 'latest', 'nightly', or a semver
|
||||||
version: '~> v2'
|
version: "~> v2"
|
||||||
args: release --clean
|
args: release --clean
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
trigger-tap-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: goreleaser
|
||||||
|
steps:
|
||||||
|
- name: "Resolve tag to dispatch"
|
||||||
|
id: tag
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||||
|
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: "Trigger tap repository update"
|
||||||
|
uses: peter-evans/repository-dispatch@v2
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.TAP_REPO_PAT }}
|
||||||
|
repository: mostlygeek/homebrew-llama-swap
|
||||||
|
event-type: new-release
|
||||||
|
client-payload: |
|
||||||
|
{
|
||||||
|
"release": {
|
||||||
|
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
name: UI Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- 'ui-svelte/**'
|
||||||
|
- '.github/workflows/ui-tests.yml'
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
paths:
|
||||||
|
- 'ui-svelte/**'
|
||||||
|
- '.github/workflows/ui-tests.yml'
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
working-directory: ui-svelte
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '24'
|
||||||
|
cache: 'npm'
|
||||||
|
cache-dependency-path: ui-svelte/package-lock.json
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Type check
|
||||||
|
run: npm run check
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: npm test
|
||||||
@@ -4,3 +4,4 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.dev/
|
||||||
|
|||||||
@@ -17,14 +17,16 @@ builds:
|
|||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
|
||||||
# use zip format for windows
|
|
||||||
archives:
|
archives:
|
||||||
- id: default
|
- id: default
|
||||||
format: tar.gz
|
formats:
|
||||||
|
- tar.gz
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
builds_info:
|
builds_info:
|
||||||
group: root
|
group: root
|
||||||
owner: root
|
owner: root
|
||||||
format_overrides:
|
format_overrides:
|
||||||
|
# use zip format for windows
|
||||||
- goos: windows
|
- goos: windows
|
||||||
format: zip
|
formats:
|
||||||
|
- zip
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
## Project Description:
|
||||||
|
|
||||||
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
|
## Tech stack
|
||||||
|
|
||||||
|
- golang
|
||||||
|
- typescript, vite and react for UI (located in ui/)
|
||||||
|
|
||||||
|
## Workflow Tasks
|
||||||
|
|
||||||
|
- when summarizing changes only include details that require further action
|
||||||
|
- just say "Done." when there is no further action
|
||||||
|
- use `gh` to create PRs and load issues
|
||||||
|
- do include Co-Authored-By or created by when committing changes or creating PRs
|
||||||
|
- keep PR descriptions short and focused on changes.
|
||||||
|
- never include a test plan
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||||
|
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||||
|
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||||
|
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||||
|
|
||||||
|
### Commit message example format:
|
||||||
|
|
||||||
|
```
|
||||||
|
proxy: add new feature
|
||||||
|
|
||||||
|
Add new feature that implements functionality X and Y.
|
||||||
|
|
||||||
|
- key change 1
|
||||||
|
- key change 2
|
||||||
|
- key change 3
|
||||||
|
|
||||||
|
fixes #123
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Reviews
|
||||||
|
|
||||||
|
- use three levels High, Medium, Low severity
|
||||||
|
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||||
|
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||||
|
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||||
|
- Low severity are nice to have changes and nits
|
||||||
|
- Include a suggestion with each discovered item
|
||||||
|
- Limit your code review to three items with the highest priority first
|
||||||
|
- Double check your discovered items and recommended remediations
|
||||||
@@ -19,32 +19,54 @@ all: mac linux simple-responder
|
|||||||
clean:
|
clean:
|
||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
proxy/ui_dist/placeholder.txt:
|
||||||
go test -short -v -count=1 ./proxy
|
mkdir -p proxy/ui_dist
|
||||||
|
touch $@
|
||||||
|
|
||||||
test-all:
|
# use cached test results while developing
|
||||||
go test -v -count=1 ./proxy
|
test-dev: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short ./proxy/...
|
||||||
|
staticcheck ./proxy/... || true
|
||||||
|
|
||||||
|
test: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short -count=1 ./proxy/...
|
||||||
|
|
||||||
|
# for CI - full test (takes longer)
|
||||||
|
test-all: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -race -count=1 ./proxy/...
|
||||||
|
|
||||||
|
ui/node_modules:
|
||||||
|
cd ui-svelte && npm install
|
||||||
|
|
||||||
|
# build react UI
|
||||||
|
ui: ui/node_modules
|
||||||
|
cd ui-svelte && npm run build
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac: ui
|
||||||
@echo "Building Mac binary..."
|
@echo "Building Mac binary..."
|
||||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||||
|
|
||||||
# Build Linux binary
|
# Build Linux binary
|
||||||
linux:
|
linux: ui
|
||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||||
|
|
||||||
# Build Windows binary
|
# Build Windows binary
|
||||||
windows:
|
windows: ui
|
||||||
@echo "Building Windows binary..."
|
@echo "Building Windows binary..."
|
||||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
# for testing proxy.Process
|
# for testing proxy.Process
|
||||||
simple-responder:
|
simple-responder:
|
||||||
@echo "Building simple responder"
|
@echo "Building simple responder"
|
||||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
|
||||||
|
|
||||||
|
simple-responder-windows:
|
||||||
|
@echo "Building simple responder for windows"
|
||||||
|
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
|
||||||
|
|
||||||
# Ensure build directory exists
|
# Ensure build directory exists
|
||||||
$(BUILD_DIR):
|
$(BUILD_DIR):
|
||||||
@@ -64,5 +86,11 @@ release:
|
|||||||
echo "tagging new version: $$new_tag"; \
|
echo "tagging new version: $$new_tag"; \
|
||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
|
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
|
||||||
|
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
||||||
|
wol-proxy: $(BUILD_DIR)
|
||||||
|
@echo "Building wol-proxy"
|
||||||
|
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean mac linux windows simple-responder
|
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
||||||
|
|||||||
@@ -1,288 +1,235 @@
|
|||||||

|

|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
||||||
|
|
||||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||||
|
|
||||||
## Features:
|
## Features:
|
||||||
|
|
||||||
- ✅ Easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||||
- ✅ Easy to config: single yaml file
|
|
||||||
- ✅ On-demand model switching
|
- ✅ On-demand model switching
|
||||||
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||||
|
- future proof, upgrade your inference servers at any time.
|
||||||
- ✅ OpenAI API supported endpoints:
|
- ✅ OpenAI API supported endpoints:
|
||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
|
- `v1/responses`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
- ✅ llama-swap custom API endpoints
|
- `v1/audio/voices`
|
||||||
- `/log` - remote log monitoring
|
- `v1/images/generations`
|
||||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `v1/images/edits`
|
||||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
- ✅ Anthropic API supported endpoints:
|
||||||
|
- `v1/messages`
|
||||||
|
- `v1/messages/count_tokens`
|
||||||
|
- ✅ llama-server (llama.cpp) supported endpoints
|
||||||
|
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||||
|
- `/infill` - for code infilling
|
||||||
|
- `/completion` - for completion endpoint
|
||||||
|
- ✅ llama-swap API
|
||||||
|
- `/ui` - web UI
|
||||||
|
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
- `/log` - remote log monitoring
|
||||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
- `/health` - just returns "OK"
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||||
- ✅ Docker and Podman support
|
- ✅ Customizable
|
||||||
- ✅ Full control over server settings per model
|
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||||
|
- Automatic unloading of models after timeout by setting a `ttl`
|
||||||
|
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||||
|
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||||
|
|
||||||
## How does llama-swap work?
|
### Web UI
|
||||||
|
|
||||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
||||||
|
|
||||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
||||||
|
|
||||||
## config.yaml
|
The Activity Page shows recent requests:
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||||
|
|
||||||
```yaml
|
## Installation
|
||||||
models:
|
|
||||||
"qwen2.5":
|
|
||||||
proxy: "http://127.0.0.1:9999"
|
|
||||||
cmd: >
|
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
|
||||||
--port 9999
|
|
||||||
|
|
||||||
"smollm2":
|
llama-swap can be installed in multiple ways
|
||||||
proxy: "http://127.0.0.1:9999"
|
|
||||||
cmd: >
|
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
|
||||||
--port 9999
|
|
||||||
```
|
|
||||||
|
|
||||||
<details>
|
1. Docker
|
||||||
<summary>But also very powerful ...</summary>
|
2. Homebrew (OSX and Linux)
|
||||||
|
3. WinGet
|
||||||
|
4. From release binaries
|
||||||
|
5. From source
|
||||||
|
|
||||||
```yaml
|
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
|
||||||
# Default (and minimum) is 15 seconds
|
|
||||||
healthCheckTimeout: 60
|
|
||||||
|
|
||||||
# Valid log levels: debug, info (default), warn, error
|
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||||
logLevel: info
|
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||||
|
|
||||||
# Automatic Port Values
|
|
||||||
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
|
|
||||||
# when you use ${PORT} you can omit a custom model.proxy value, as it will
|
|
||||||
# default to http://localhost:${PORT}
|
|
||||||
|
|
||||||
# override the default port (5800) for automatic port values
|
|
||||||
startPort: 10001
|
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
|
||||||
models:
|
|
||||||
"llama":
|
|
||||||
# multiline for readability
|
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
|
|
||||||
# environment variables to pass to the command
|
|
||||||
env:
|
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
|
||||||
# can be omitted if you use an automatic ${PORT} in cmd
|
|
||||||
proxy: http://127.0.0.1:8999
|
|
||||||
|
|
||||||
# aliases names to use this model for
|
|
||||||
aliases:
|
|
||||||
- "gpt-4o-mini"
|
|
||||||
- "gpt-3.5-turbo"
|
|
||||||
|
|
||||||
# check this path for an HTTP 200 OK before serving requests
|
|
||||||
# default: /health to match llama.cpp
|
|
||||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
|
||||||
# until the model is ready
|
|
||||||
checkEndpoint: /custom-endpoint
|
|
||||||
|
|
||||||
# automatically unload the model after this many seconds
|
|
||||||
# ttl values must be a value greater than 0
|
|
||||||
# default: 0 = never unload model
|
|
||||||
ttl: 60
|
|
||||||
|
|
||||||
# `useModelName` overrides the model name in the request
|
|
||||||
# and sends a specific name to the upstream server
|
|
||||||
useModelName: "qwen:qwq"
|
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
|
||||||
# but they can still be requested as normal
|
|
||||||
"qwen-unlisted":
|
|
||||||
unlisted: true
|
|
||||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
|
||||||
"docker-llama":
|
|
||||||
proxy: "http://127.0.0.1:${PORT}"
|
|
||||||
cmd: >
|
|
||||||
docker run --name dockertest
|
|
||||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
|
||||||
|
|
||||||
# Groups provide advanced controls over model swapping behaviour. Using groups
|
|
||||||
# some models can be kept loaded indefinitely, while others are swapped out.
|
|
||||||
#
|
|
||||||
# Tips:
|
|
||||||
#
|
|
||||||
# - models must be defined above in the Models section
|
|
||||||
# - a model can only be a member of one group
|
|
||||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
|
||||||
# - see issue #109 for details
|
|
||||||
#
|
|
||||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
|
||||||
groups:
|
|
||||||
# group1 is the default behaviour of llama-swap where only one model is allowed
|
|
||||||
# to run a time across the whole llama-swap instance
|
|
||||||
"group1":
|
|
||||||
# swap controls the model swapping behaviour in within the group
|
|
||||||
# - true : only one model is allowed to run at a time
|
|
||||||
# - false: all models can run together, no swapping
|
|
||||||
swap: true
|
|
||||||
|
|
||||||
# exclusive controls how the group affects other groups
|
|
||||||
# - true: causes all other groups to unload their models when this group runs a model
|
|
||||||
# - false: does not affect other groups
|
|
||||||
exclusive: true
|
|
||||||
|
|
||||||
# members references the models defined above
|
|
||||||
members:
|
|
||||||
- "llama"
|
|
||||||
- "qwen-unlisted"
|
|
||||||
|
|
||||||
# models in this group are never unloaded
|
|
||||||
"group2":
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "docker-llama"
|
|
||||||
# (not defined above, here for example)
|
|
||||||
- "modelA"
|
|
||||||
- "modelB"
|
|
||||||
|
|
||||||
"forever":
|
|
||||||
# setting persistent to true causes the group to never be affected by the swapping behaviour of
|
|
||||||
# other groups. It is a shortcut to keeping some models always loaded.
|
|
||||||
persistent: true
|
|
||||||
|
|
||||||
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "forever-modelA"
|
|
||||||
- "forever-modelB"
|
|
||||||
- "forever-modelc"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Use Case Examples
|
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
|
||||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
|
||||||
|
|
||||||
Docker is the quickest way to try out llama-swap:
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# use CPU inference
|
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
|
||||||
|
|
||||||
|
# run with a custom configuration and models directory
|
||||||
# qwen2.5 0.5B
|
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "Authorization: Bearer no-key" \
|
|
||||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
|
||||||
jq -r '.choices[0].message.content'
|
|
||||||
|
|
||||||
|
|
||||||
# SmolLM2 135M
|
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "Authorization: Bearer no-key" \
|
|
||||||
-d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
|
||||||
jq -r '.choices[0].message.content'
|
|
||||||
```
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>Docker images are nightly ...</summary>
|
|
||||||
|
|
||||||
They include:
|
|
||||||
|
|
||||||
- `ghcr.io/mostlygeek/llama-swap:cpu`
|
|
||||||
- `ghcr.io/mostlygeek/llama-swap:cuda`
|
|
||||||
- `ghcr.io/mostlygeek/llama-swap:intel`
|
|
||||||
- `ghcr.io/mostlygeek/llama-swap:vulkan`
|
|
||||||
- ROCm disabled until fixed in llama.cpp container
|
|
||||||
|
|
||||||
Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716`
|
|
||||||
|
|
||||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
-v /path/to/models:/models \
|
-v /path/to/models:/models \
|
||||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
ghcr.io/mostlygeek/llama-swap:cuda
|
ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
|
|
||||||
|
# configuration hot reload supported with a
|
||||||
|
# directory volume mount
|
||||||
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
|
-v /path/to/models:/models \
|
||||||
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
|
-v /path/to/config:/config \
|
||||||
|
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
more examples
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# pull latest images per platform
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:intel
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||||
|
|
||||||
|
# tagged llama-swap, platform and llama-server version images
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||||
|
|
||||||
|
# non-root cuda
|
||||||
|
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
### Homebrew Install (macOS/Linux)
|
||||||
|
|
||||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
```shell
|
||||||
|
brew tap mostlygeek/llama-swap
|
||||||
|
brew install llama-swap
|
||||||
|
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
### WinGet Install (Windows)
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`.
|
> [!NOTE]
|
||||||
Available flags:
|
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
|
||||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
|
||||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
```shell
|
||||||
- `--version`: Show version information and exit.
|
# install
|
||||||
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
C:\> winget install llama-swap
|
||||||
|
|
||||||
|
# upgrade
|
||||||
|
C:\> winget upgrade llama-swap
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-built Binaries
|
||||||
|
|
||||||
|
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
|
||||||
|
|
||||||
### Building from source
|
### Building from source
|
||||||
|
|
||||||
1. Install golang for your system
|
1. Building requires Go and Node.js (for UI).
|
||||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||||
1. `make clean all`
|
1. `make clean all`
|
||||||
1. Binaries will be in `build/` subdirectory
|
1. look in the `build/` subdirectory for the llama-swap binary
|
||||||
|
|
||||||
## Monitoring Logs
|
## Configuration
|
||||||
|
|
||||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
```yaml
|
||||||
|
# minimum viable config.yaml
|
||||||
|
|
||||||
Of course, CLI access is also supported:
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||||
|
```
|
||||||
|
|
||||||
```shell
|
That's all you need to get started:
|
||||||
|
|
||||||
|
1. `models` - holds all model configurations
|
||||||
|
2. `model1` - the ID used in API calls
|
||||||
|
3. `cmd` - the command to run to start the server.
|
||||||
|
4. `${PORT}` - an automatically assigned port number
|
||||||
|
|
||||||
|
Almost all configuration settings are optional and can be added one step at a time:
|
||||||
|
|
||||||
|
- Advanced features
|
||||||
|
- `groups` to run multiple models at once
|
||||||
|
- `hooks` to run things on startup
|
||||||
|
- `macros` reusable snippets
|
||||||
|
- Model customization
|
||||||
|
- `ttl` to automatically unload models
|
||||||
|
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||||
|
- `env` to pass custom environment variables to inference servers
|
||||||
|
- `cmdStop` gracefully stop Docker/Podman containers
|
||||||
|
- `useModelName` to override model names sent to upstream servers
|
||||||
|
- `${PORT}` automatic port variables for dynamic port assignment
|
||||||
|
- `filters` rewrite parts of requests before sending to the upstream server
|
||||||
|
|
||||||
|
See the [configuration documentation](docs/configuration.md) for all options.
|
||||||
|
|
||||||
|
## How does llama-swap work?
|
||||||
|
|
||||||
|
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
||||||
|
|
||||||
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
|
## Reverse Proxy Configuration (nginx)
|
||||||
|
|
||||||
|
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks Server‑Sent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
|
||||||
|
|
||||||
|
Recommended nginx configuration snippets:
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
# SSE for UI events/logs
|
||||||
|
location /api/events {
|
||||||
|
proxy_pass http://your-llama-swap-backend;
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
|
}
|
||||||
|
|
||||||
|
# Streaming chat completions (stream=true)
|
||||||
|
location /v1/chat/completions {
|
||||||
|
proxy_pass http://your-llama-swap-backend;
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
|
||||||
|
|
||||||
|
## Monitoring Logs on the CLI
|
||||||
|
|
||||||
|
```sh
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
$ curl http://host/logs
|
||||||
|
|
||||||
# streams combined logs
|
# streams combined logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns http://host/logs/stream
|
||||||
|
|
||||||
# just llama-swap's logs
|
# stream llama-swap's proxy status logs
|
||||||
curl -Ns 'http://host/logs/stream/proxy'
|
curl -Ns http://host/logs/stream/proxy
|
||||||
|
|
||||||
# just upstream's logs
|
# stream logs from upstream processes that llama-swap loads
|
||||||
curl -Ns 'http://host/logs/stream/upstream'
|
curl -Ns http://host/logs/stream/upstream
|
||||||
|
|
||||||
|
# stream logs only from a specific model
|
||||||
|
curl -Ns http://host/logs/stream/{model_id}
|
||||||
|
|
||||||
# stream and filter logs with linux pipes
|
# stream and filter logs with linux pipes
|
||||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
# skips history and just streams new log entries
|
# appending ?no-history will disable sending buffered history first
|
||||||
curl -Ns 'http://host/logs/stream?no-history'
|
curl -Ns 'http://host/logs/stream?no-history'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -290,34 +237,11 @@ curl -Ns 'http://host/logs/stream?no-history'
|
|||||||
|
|
||||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||||
|
|
||||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
|
||||||
|
|
||||||
## Systemd Unit Files
|
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
|
||||||
|
|
||||||
`/etc/systemd/system/llama-swap.service`
|
|
||||||
|
|
||||||
```
|
|
||||||
[Unit]
|
|
||||||
Description=llama-swap
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
User=nobody
|
|
||||||
|
|
||||||
# set this to match your environment
|
|
||||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
|
||||||
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=3
|
|
||||||
StartLimitBurst=3
|
|
||||||
StartLimitInterval=30
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> ⭐️ Star this project to help others discover it!
|
||||||
|
|
||||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||||
|
|
||||||
|
## Current Issues
|
||||||
|
|
||||||
|
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||||
|
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||||
|
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||||
|
4. Linked list structure has poor cache locality and pointer overhead
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### New CircularBuffer Type
|
||||||
|
|
||||||
|
Create a simple circular byte buffer with:
|
||||||
|
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||||
|
- `head` and `size` integers to track write position and data length
|
||||||
|
- No per-write allocations
|
||||||
|
|
||||||
|
### API Requirements
|
||||||
|
|
||||||
|
The new buffer must support:
|
||||||
|
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||||
|
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||||
|
|
||||||
|
### Implementation Details
|
||||||
|
|
||||||
|
```go
|
||||||
|
type circularBuffer struct {
|
||||||
|
data []byte // pre-allocated capacity
|
||||||
|
head int // next write position
|
||||||
|
size int // current number of bytes stored (0 to cap)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Write logic:**
|
||||||
|
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||||
|
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||||
|
- Update `head` and `size` accordingly
|
||||||
|
- Data is copied into the internal buffer (not stored by reference)
|
||||||
|
|
||||||
|
**GetHistory logic:**
|
||||||
|
- Calculate start position: `(head - size + cap) % cap`
|
||||||
|
- If not wrapped: single slice copy
|
||||||
|
- If wrapped: two copies (end of buffer + beginning)
|
||||||
|
- Returns a **new slice** (copy), not a view into internal buffer
|
||||||
|
|
||||||
|
### Immutability Guarantees (must preserve)
|
||||||
|
|
||||||
|
Per existing tests:
|
||||||
|
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||||
|
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||||
|
|
||||||
|
## Files to Modify
|
||||||
|
|
||||||
|
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||||
|
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||||
|
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||||
|
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||||
|
|
||||||
|
Add new tests:
|
||||||
|
- Test buffer wrap-around behavior
|
||||||
|
- Test large writes that exceed buffer capacity
|
||||||
|
- Test exact capacity boundary conditions
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||||
|
- [ ] Implement `Write()` method for circular buffer
|
||||||
|
- [ ] Implement `GetHistory()` method for circular buffer
|
||||||
|
- [ ] Update `LogMonitor` struct to use new buffer
|
||||||
|
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||||
|
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||||
|
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||||
|
- [ ] Remove `"container/ring"` import
|
||||||
|
- [ ] Run `make test-dev` to verify existing tests pass
|
||||||
|
- [ ] Add wrap-around test case
|
||||||
|
- [ ] Run `make test-all` for final validation
|
||||||
@@ -0,0 +1,292 @@
|
|||||||
|
# Add Model Metadata Support with Typed Macros
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||||
|
|
||||||
|
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. Enhanced Macro System
|
||||||
|
|
||||||
|
**Current State:**
|
||||||
|
|
||||||
|
- Macros are defined as `map[string]string` at both global and model levels
|
||||||
|
- Only string substitution is supported
|
||||||
|
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||||
|
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Implement type-preserving macro substitution:
|
||||||
|
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||||
|
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||||
|
- Add validation to ensure macro values are scalar types only
|
||||||
|
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||||
|
|
||||||
|
**Implementation Details:**
|
||||||
|
|
||||||
|
- Create a generic helper function to perform macro substitution that:
|
||||||
|
- Takes a value of type `any`
|
||||||
|
- Recursively processes maps, slices, and scalar values
|
||||||
|
- Replaces `${macro_name}` patterns with macro values
|
||||||
|
- Preserves types for direct substitution
|
||||||
|
- Converts to strings for interpolated substitution
|
||||||
|
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||||
|
- Maintain backward compatibility with existing string-only macros
|
||||||
|
|
||||||
|
### 2. Metadata Field in ModelConfig
|
||||||
|
|
||||||
|
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||||
|
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||||
|
- Apply macro substitution to metadata values during config loading
|
||||||
|
|
||||||
|
**Schema Requirements:**
|
||||||
|
|
||||||
|
- Metadata is optional (default: empty/nil map)
|
||||||
|
- Supports nested structures (objects within objects, arrays, etc.)
|
||||||
|
- All string values within metadata undergo macro substitution
|
||||||
|
- Type preservation rules apply as described above
|
||||||
|
|
||||||
|
### 3. Macro Substitution in Metadata
|
||||||
|
|
||||||
|
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||||
|
|
||||||
|
**Process Flow:**
|
||||||
|
|
||||||
|
1. After loading YAML configuration
|
||||||
|
2. After model-level and global macro merging
|
||||||
|
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||||
|
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||||
|
5. Process recursively through all nested structures
|
||||||
|
|
||||||
|
**Substitution Rules:**
|
||||||
|
|
||||||
|
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||||
|
- `temperature: ${temp}` → keeps float type from temp macro
|
||||||
|
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||||
|
- Arrays and nested objects are processed recursively
|
||||||
|
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||||
|
|
||||||
|
### 4. API Response Updates
|
||||||
|
|
||||||
|
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||||
|
|
||||||
|
**Current Behavior:**
|
||||||
|
|
||||||
|
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||||
|
- Optionally includes: `name`, `description`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add metadata to each model record under the key `llamaswap_meta`
|
||||||
|
- Only include `llamaswap_meta` if metadata is non-empty
|
||||||
|
- Preserve all types when marshaling to JSON
|
||||||
|
- Maintain existing sorting by model ID
|
||||||
|
|
||||||
|
**Example Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "llama",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1234567890,
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
"name": "llama 3.1 8B",
|
||||||
|
"description": "A small but capable model",
|
||||||
|
"llamaswap_meta": {
|
||||||
|
"port": 10001,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||||
|
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||||
|
"an_obj": {
|
||||||
|
"a": "1",
|
||||||
|
"b": 2,
|
||||||
|
"c": [0.7, false, "model: llama"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Validation and Error Handling
|
||||||
|
|
||||||
|
**Macro Validation:**
|
||||||
|
|
||||||
|
- Extend `validateMacro()` to accept values of type `any`
|
||||||
|
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Reject complex types (maps, slices, structs) as macro values
|
||||||
|
- Maintain existing validation for macro names and lengths
|
||||||
|
|
||||||
|
**Configuration Loading:**
|
||||||
|
|
||||||
|
- Fail fast if unknown macros are found in metadata
|
||||||
|
- Provide clear error messages indicating which model and field contains errors
|
||||||
|
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### Test 1: Model-Level Macros with Different Types
|
||||||
|
|
||||||
|
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define model with macros of each scalar type
|
||||||
|
- Verify metadata correctly substitutes and preserves types
|
||||||
|
- Test direct substitution (`port: ${PORT}`)
|
||||||
|
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||||
|
- Verify nested objects and arrays work correctly
|
||||||
|
|
||||||
|
### Test 2: Global and Model Macro Precedence
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define same macro at global and model level with different types
|
||||||
|
- Verify model-level macro takes precedence
|
||||||
|
- Test metadata uses correct macro value
|
||||||
|
- Verify type is preserved from the winning macro
|
||||||
|
|
||||||
|
### Test 3: Macro Validation
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Test that complex types (maps, arrays) are rejected as macro values
|
||||||
|
- Verify error message includes: macro name and type that was rejected
|
||||||
|
- Test that scalar types (string, int, float, bool) are accepted
|
||||||
|
- Each type should load without error
|
||||||
|
- Test macro name validation still works with `any` types
|
||||||
|
- Invalid characters, reserved names, length limits should still be enforced
|
||||||
|
|
||||||
|
### Test 4: Metadata in API Response
|
||||||
|
|
||||||
|
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
|
||||||
|
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Model with metadata → verify `llamaswap_meta` key appears
|
||||||
|
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||||
|
- Verify all types are correctly marshaled to JSON
|
||||||
|
- Verify nested structures are preserved
|
||||||
|
- Verify macro substitution has occurred before serialization
|
||||||
|
|
||||||
|
### Test 5: Unknown Macros in Metadata
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Use undefined macro in metadata
|
||||||
|
- Verify configuration loading fails with clear error
|
||||||
|
- Error should indicate model name and that macro is undefined
|
||||||
|
|
||||||
|
### Test 6: Recursive Substitution
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Metadata with deeply nested structures
|
||||||
|
- Arrays containing objects with macros
|
||||||
|
- Objects containing arrays with macros
|
||||||
|
- Mixed string interpolation and direct substitution at various nesting levels
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Configuration Schema Changes
|
||||||
|
|
||||||
|
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||||
|
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||||
|
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||||
|
- [x] Add validation logic to ensure macro values are scalar types only
|
||||||
|
|
||||||
|
### Macro Substitution Logic
|
||||||
|
|
||||||
|
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||||
|
- [x] Implement type-preserving direct substitution logic
|
||||||
|
- [x] Implement string interpolation with type conversion
|
||||||
|
- [x] Handle maps: recursively process all values
|
||||||
|
- [x] Handle slices: recursively process all elements
|
||||||
|
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||||
|
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||||
|
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||||
|
|
||||||
|
### API Response Changes
|
||||||
|
|
||||||
|
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||||
|
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||||
|
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||||
|
- [x] Verify JSON marshaling preserves all types correctly
|
||||||
|
|
||||||
|
### Testing - Config Package
|
||||||
|
|
||||||
|
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
### Testing - Model Config Package
|
||||||
|
|
||||||
|
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||||
|
- [x] Test metadata with various scalar types
|
||||||
|
- [x] Test metadata with nested objects and arrays
|
||||||
|
|
||||||
|
### Testing - Proxy Manager
|
||||||
|
|
||||||
|
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
- [x] Add test case for model with metadata
|
||||||
|
- [x] Add test case for model without metadata
|
||||||
|
- [x] Verify `llamaswap_meta` key presence/absence
|
||||||
|
- [x] Verify type preservation in JSON output
|
||||||
|
- [x] Verify macro substitution has occurred
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||||
|
- [x] No additional documentation needed per project instructions
|
||||||
|
|
||||||
|
## Known Issues and Considerations
|
||||||
|
|
||||||
|
### Inconsistencies
|
||||||
|
|
||||||
|
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||||
|
|
||||||
|
### Design Decisions
|
||||||
|
|
||||||
|
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||||
|
|
||||||
|
- Avoids potential collisions with OpenAI API standard fields
|
||||||
|
- Makes it clear this is llama-swap specific metadata
|
||||||
|
- Easier for clients to distinguish standard vs. custom fields
|
||||||
|
|
||||||
|
2. **Why support nested structures?**
|
||||||
|
|
||||||
|
- Provides maximum flexibility for users
|
||||||
|
- Aligns with the schemaless design principle
|
||||||
|
- Example config already demonstrates this capability
|
||||||
|
|
||||||
|
3. **Why validate macro types?**
|
||||||
|
- Prevents confusing behavior (e.g., substituting a map)
|
||||||
|
- Makes configuration errors explicit at load time
|
||||||
|
- Simpler implementation and testing
|
||||||
@@ -0,0 +1,397 @@
|
|||||||
|
# Improve macro-in-macro support
|
||||||
|
|
||||||
|
**Status: COMPLETED ✅**
|
||||||
|
|
||||||
|
## Title
|
||||||
|
|
||||||
|
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||||
|
|
||||||
|
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||||
|
|
||||||
|
**Outcomes:**
|
||||||
|
- Macros can reference other macros defined earlier in the config
|
||||||
|
- Macro substitution is deterministic and order-dependent
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||||
|
- All existing tests pass
|
||||||
|
- New tests validate substitution order and self-reference detection
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. YAML Parsing Strategy
|
||||||
|
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||||
|
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||||
|
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||||
|
|
||||||
|
### 2. Data Structure Changes
|
||||||
|
|
||||||
|
#### Current Implementation (config.go:19)
|
||||||
|
```go
|
||||||
|
type MacroList map[string]any
|
||||||
|
```
|
||||||
|
|
||||||
|
#### New Implementation
|
||||||
|
```go
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||||
|
|
||||||
|
### 3. Macro Substitution Order Rules
|
||||||
|
|
||||||
|
The substitution must follow this hierarchy (from most specific to least):
|
||||||
|
|
||||||
|
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||||
|
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||||
|
3. **Global macros** (first): Defined at config root level
|
||||||
|
|
||||||
|
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||||
|
- The last macro defined is substituted first
|
||||||
|
- This allows later macros to reference earlier ones
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
|
||||||
|
### 4. Macro Reference Rules
|
||||||
|
|
||||||
|
**Allowed:**
|
||||||
|
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||||
|
- Model macros can reference global macros
|
||||||
|
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||||
|
|
||||||
|
**Prohibited:**
|
||||||
|
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||||
|
- Macro cannot reference macros defined **after** it
|
||||||
|
- No circular references (prevented by single-pass, ordered substitution)
|
||||||
|
|
||||||
|
### 5. Validation Requirements
|
||||||
|
|
||||||
|
Add validation to detect:
|
||||||
|
- **Self-references:** Macro value contains reference to its own name
|
||||||
|
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||||
|
|
||||||
|
Error messages should be clear:
|
||||||
|
```
|
||||||
|
macro 'foo' contains self-reference
|
||||||
|
unknown macro '${bar}' in model.cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Implementation Changes
|
||||||
|
|
||||||
|
#### Files to Modify
|
||||||
|
|
||||||
|
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||||
|
- Line 19: Change `MacroList` type definition
|
||||||
|
- Line 69: Update `Macros MacroList` field
|
||||||
|
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||||
|
- Line 175-188: Update model-level macro validation
|
||||||
|
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||||
|
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||||
|
- Line 389-415: Update `validateMacro` to detect self-references
|
||||||
|
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||||
|
|
||||||
|
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||||
|
- Line 33: Update `Macros MacroList` field type
|
||||||
|
|
||||||
|
3. **All test files**
|
||||||
|
- Update test fixtures to use ordered macro definitions
|
||||||
|
- Ensure tests specify macro order explicitly
|
||||||
|
|
||||||
|
#### Core Algorithm
|
||||||
|
|
||||||
|
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
for _, entry := range config.Macros {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reserved MODEL_ID macro at the end
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Check if PORT macro is needed
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||||
|
// enforce ${PORT} used in both cmd and proxy
|
||||||
|
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro to the end (highest priority)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this new helper function to replace `substituteMetadataMacros`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Self-Reference Detection
|
||||||
|
|
||||||
|
Add to `validateMacro` function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
// ... existing validation ...
|
||||||
|
|
||||||
|
// Check for self-reference
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(str, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### 1. Migration Tests
|
||||||
|
- **Test:** All existing macro tests still pass after YAML library migration
|
||||||
|
- **Files:** All `*_test.go` files with macro tests
|
||||||
|
|
||||||
|
### 2. Macro Order Tests
|
||||||
|
|
||||||
|
#### Test: Macro-in-macro substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "echo ${B}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||||
|
|
||||||
|
#### Test: LIFO substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "load ${full}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||||
|
|
||||||
|
#### Test: Model macro overrides global
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: "echo ${msg}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||||
|
|
||||||
|
### 3. Reserved Macro Tests
|
||||||
|
|
||||||
|
#### Test: MODEL_ID substituted in macro
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: "${podman-llama} -m model.gguf"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||||
|
|
||||||
|
### 4. Error Detection Tests
|
||||||
|
|
||||||
|
#### Test: Self-reference detection
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||||
|
|
||||||
|
#### Test: Undefined macro reference
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||||
|
|
||||||
|
### 5. Regression Tests
|
||||||
|
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||||
|
- Ensure all pass without modification (except test fixtures if needed)
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Phase 1: Data Structure Changes
|
||||||
|
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||||
|
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||||
|
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||||
|
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||||
|
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||||
|
- [ ] Implement helper functions:
|
||||||
|
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||||
|
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||||
|
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||||
|
|
||||||
|
### Phase 2: Macro Validation Updates
|
||||||
|
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||||
|
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||||
|
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||||
|
- [ ] Test self-reference detection with new test case
|
||||||
|
|
||||||
|
### Phase 3: Macro Substitution Algorithm
|
||||||
|
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||||
|
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||||
|
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||||
|
- [ ] Substitute in metadata within same loop
|
||||||
|
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||||
|
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||||
|
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||||
|
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||||
|
|
||||||
|
### Phase 4: Testing
|
||||||
|
- [ ] Run `make test-dev` - fix any static checking errors
|
||||||
|
- [ ] Add test: macro-in-macro basic substitution
|
||||||
|
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||||
|
- [ ] Add test: MODEL_ID in global macro used by model
|
||||||
|
- [ ] Add test: PORT in global macro used by model
|
||||||
|
- [ ] Add test: model macro overrides global macro in substitution
|
||||||
|
- [ ] Add test: self-reference detection error
|
||||||
|
- [ ] Add test: undefined macro reference error
|
||||||
|
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||||
|
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||||
|
|
||||||
|
### Phase 5: Documentation
|
||||||
|
- [ ] Update plan status in this file (mark completed)
|
||||||
|
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||||
|
- [ ] Verify no new error messages need user documentation
|
||||||
|
|
||||||
|
## Bug Example (Original Issue)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": >
|
||||||
|
podman run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||||
|
|
||||||
|
"standard-options": >
|
||||||
|
--no-mmap --jinja
|
||||||
|
|
||||||
|
"kv8": >
|
||||||
|
-fa on -ctk q8_0 -ctv q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Current Bug:**
|
||||||
|
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||||
|
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||||
|
|
||||||
|
**After Fix:**
|
||||||
|
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||||
|
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||||
|
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||||
|
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||||
|
// to make sure all the requests are accounted for.
|
||||||
|
//
|
||||||
|
// requests can be sent in parallel, and the tool will report the results.
|
||||||
|
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// ----- CLI arguments ----------------------------------------------------
|
||||||
|
var (
|
||||||
|
baseurl string
|
||||||
|
modelName string
|
||||||
|
totalRequests int
|
||||||
|
parallelization int
|
||||||
|
)
|
||||||
|
|
||||||
|
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||||
|
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||||
|
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||||
|
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if baseurl == "" || modelName == "" {
|
||||||
|
fmt.Println("Error: both -baseurl and -model are required.")
|
||||||
|
flag.Usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if totalRequests <= 0 {
|
||||||
|
fmt.Println("Error: -requests must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if parallelization <= 0 {
|
||||||
|
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- HTTP client -------------------------------------------------------
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Tracking response codes -------------------------------------------
|
||||||
|
statusCounts := make(map[int]int) // map[statusCode]count
|
||||||
|
var mu sync.Mutex // protects statusCounts
|
||||||
|
|
||||||
|
// ----- Request queue (buffered channel) ----------------------------------
|
||||||
|
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||||
|
|
||||||
|
// Goroutine to fill the request queue
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < totalRequests; i++ {
|
||||||
|
requests <- i + 1
|
||||||
|
}
|
||||||
|
close(requests)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ----- Worker pool -------------------------------------------------------
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < parallelization; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(workerID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for reqID := range requests {
|
||||||
|
// Build request payload as a single line JSON string
|
||||||
|
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||||
|
|
||||||
|
// Send POST request
|
||||||
|
req, err := http.NewRequest(http.MethodPost,
|
||||||
|
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||||
|
bytes.NewReader([]byte(payload)))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// Record status code
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[resp.StatusCode]++
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Status ticker (prints every second) -------------------------------
|
||||||
|
done := make(chan struct{})
|
||||||
|
tickerDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
startTime := time.Now()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
mu.Lock()
|
||||||
|
// Compute how many requests have completed so far
|
||||||
|
completed := 0
|
||||||
|
for _, cnt := range statusCounts {
|
||||||
|
completed += cnt
|
||||||
|
}
|
||||||
|
// Calculate duration and progress
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
progress := completed * 100 / totalRequests
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||||
|
mu.Unlock()
|
||||||
|
case <-done:
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||||
|
close(tickerDone)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for all workers to finish
|
||||||
|
wg.Wait()
|
||||||
|
close(done) // stops the status-update goroutine
|
||||||
|
<-tickerDone // give ticker time to finish / print
|
||||||
|
|
||||||
|
// ----- Summary ------------------------------------------------------------
|
||||||
|
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||||
|
mu.Lock()
|
||||||
|
for code, cnt := range statusCounts {
|
||||||
|
if code == -1 {
|
||||||
|
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%d : %d\n", code, cnt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
**
|
||||||
|
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||||
|
|
||||||
|
- process is killed externally, what happens with cmd.Wait() *
|
||||||
|
✔︎ it returns. catches crashes.*
|
||||||
|
|
||||||
|
- process ignores SIGTERM*
|
||||||
|
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||||
|
|
||||||
|
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||||
|
x they stick around. have to be manually killed.*
|
||||||
|
|
||||||
|
- .WithTimeout()'s cancel is called *
|
||||||
|
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||||
|
|
||||||
|
- parent receives SIGINT/SIGTERM, what happens
|
||||||
|
✔︎ waits for child process to exit, then exits gracefully.
|
||||||
|
*/
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||||
|
cmd := exec.CommandContext(ctx,
|
||||||
|
"../../build/simple-responder_darwin_arm64",
|
||||||
|
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||||
|
)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// set a wait delay before signing sig kill
|
||||||
|
cmd.WaitDelay = 500 * time.Millisecond
|
||||||
|
cmd.Cancel = func() error {
|
||||||
|
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||||
|
cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
|
||||||
|
//return nil
|
||||||
|
|
||||||
|
// this error is returned by cmd.Wait(), and can be used to
|
||||||
|
// single an error when the process couldn't be normally terminated
|
||||||
|
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||||
|
// as WaitDelay timing out will override the any error set here.
|
||||||
|
//
|
||||||
|
// test by enabling/disabling -ignore-sig-term on the process
|
||||||
|
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||||
|
// without it, it will show the "new error from cancel"
|
||||||
|
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
fmt.Println("Error starting process:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||||
|
// this program to eventually exit gracefully.
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
signal := <-sigChan
|
||||||
|
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||||
|
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||||
|
}
|
||||||
|
fmt.Println("✔︎ Child process exited, Done.")
|
||||||
|
}
|
||||||
@@ -35,17 +35,90 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "application/json")
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// Check if streaming is requested
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||||
time.Sleep(wait)
|
isStreaming := c.Query("stream") == "true"
|
||||||
|
|
||||||
|
if isStreaming {
|
||||||
|
// Set headers for streaming
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 10 "asdf" tokens
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := gin.H{
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"choices": []gin.H{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": gin.H{
|
||||||
|
"content": "asdf",
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", data)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final data with usage info
|
||||||
|
finalData := gin.H{
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
// add timings to simulate llama.cpp
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", finalData)
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Send [DONE]
|
||||||
|
c.SSEvent("message", "[DONE]")
|
||||||
|
c.Writer.Flush()
|
||||||
|
} else {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
"request_body": string(bodyBytes),
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"responseMessage": *responseMessage,
|
|
||||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -71,10 +144,28 @@ func main() {
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"responseMessage": *responseMessage,
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// llama-server compatibility: /completion
|
||||||
|
r.POST("/completion", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||||
// Parse the multipart form
|
// Parse the multipart form
|
||||||
@@ -119,6 +210,11 @@ func main() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||||
|
model := c.Query("model")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||||
|
})
|
||||||
|
|
||||||
r.GET("/slow-respond", func(c *gin.Context) {
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
echo := c.Query("echo")
|
echo := c.Query("echo")
|
||||||
delay := c.Query("delay")
|
delay := c.Query("delay")
|
||||||
@@ -223,13 +319,13 @@ runloop:
|
|||||||
if countSigInt > 1 {
|
if countSigInt > 1 {
|
||||||
break runloop
|
break runloop
|
||||||
} else {
|
} else {
|
||||||
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
|
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||||
}
|
}
|
||||||
case syscall.SIGTERM:
|
case syscall.SIGTERM:
|
||||||
if *ignoreSigTerm {
|
if *ignoreSigTerm {
|
||||||
log.Println("Ignoring SIGTERM")
|
log.Println("Ignoring SIGTERM")
|
||||||
} else {
|
} else {
|
||||||
log.Println("Recieved SIGTERM, shutting down")
|
log.Println("Received SIGTERM, shutting down")
|
||||||
break runloop
|
break runloop
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
# wol-proxy
|
||||||
|
|
||||||
|
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
|
||||||
|
|
||||||
|
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
|
||||||
|
|
||||||
|
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# minimal
|
||||||
|
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
|
||||||
|
|
||||||
|
# everything
|
||||||
|
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
|
||||||
|
# use debug log level
|
||||||
|
-log debug \
|
||||||
|
# altenerative listening port
|
||||||
|
-listen localhost:9999 \
|
||||||
|
# seconds to hold requests waiting for upstream to be ready
|
||||||
|
-timeout 30
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
`GET /status` - that's it. Everything else is proxied to the upstream server.
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>Loading...</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: sans-serif;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
height: 100vh;
|
||||||
|
margin: 0;
|
||||||
|
background: #f5f5f5;
|
||||||
|
}
|
||||||
|
.loader {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.stats {
|
||||||
|
font-size: 18px;
|
||||||
|
color: #333;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
.stats-label {
|
||||||
|
color: #666;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="loader">
|
||||||
|
<p>Waking up upstream server...</p>
|
||||||
|
<div class="stats">
|
||||||
|
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
|
||||||
|
<div><span id="attempts"> </span></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
|
var startTime = Date.now();
|
||||||
|
var attempts = 0;
|
||||||
|
|
||||||
|
setInterval(function() {
|
||||||
|
var elapsed = (Date.now() - startTime) / 1000;
|
||||||
|
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
|
||||||
|
}, 100);
|
||||||
|
|
||||||
|
// Check status every second
|
||||||
|
setInterval(function() {
|
||||||
|
attempts++;
|
||||||
|
var dots = '.'.repeat((attempts % 10) || 10);
|
||||||
|
document.getElementById('attempts').textContent = dots;
|
||||||
|
|
||||||
|
fetch('/status')
|
||||||
|
.then(function(r) { return r.text(); })
|
||||||
|
.then(function(t) {
|
||||||
|
if (t.indexOf('status: ready') !== -1) {
|
||||||
|
location.reload();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(function() {});
|
||||||
|
}, 1000);
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed index.html
|
||||||
|
var loadingPageHTML string
|
||||||
|
|
||||||
|
var (
|
||||||
|
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
|
||||||
|
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
|
||||||
|
flagListen = flag.String("listen", ":8080", "listen address to listen on")
|
||||||
|
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
|
||||||
|
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
switch *flagLog {
|
||||||
|
case "debug":
|
||||||
|
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||||
|
case "info":
|
||||||
|
slog.SetLogLoggerLevel(slog.LevelInfo)
|
||||||
|
case "warn":
|
||||||
|
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||||
|
case "error":
|
||||||
|
slog.SetLogLoggerLevel(slog.LevelError)
|
||||||
|
default:
|
||||||
|
slog.Error("invalid log level", "logLevel", *flagLog)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate flags
|
||||||
|
if *flagListen == "" {
|
||||||
|
slog.Error("listen address is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagMac == "" {
|
||||||
|
slog.Error("mac address is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagTimeout < 1 {
|
||||||
|
slog.Error("timeout must be greater than 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreamURL *url.URL
|
||||||
|
var err error
|
||||||
|
// validate mac address
|
||||||
|
if _, err = net.ParseMAC(*flagMac); err != nil {
|
||||||
|
slog.Error("invalid mac address", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagUpstream == "" {
|
||||||
|
slog.Error("upstream proxy address is required")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("error parsing upstream url", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := newProxy(upstreamURL)
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: *flagListen,
|
||||||
|
Handler: proxy,
|
||||||
|
}
|
||||||
|
|
||||||
|
// start the server
|
||||||
|
go func() {
|
||||||
|
slog.Info("server starting on", "address", *flagListen)
|
||||||
|
if err := server.ListenAndServe(); err != nil {
|
||||||
|
slog.Error("error starting server", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// graceful shutdown
|
||||||
|
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||||
|
<-ctx.Done()
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
notready upstreamStatus = "not ready"
|
||||||
|
ready upstreamStatus = "ready"
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyServer struct {
|
||||||
|
upstreamProxy *httputil.ReverseProxy
|
||||||
|
failCount int
|
||||||
|
statusMutex sync.RWMutex
|
||||||
|
status upstreamStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxy(url *url.URL) *proxyServer {
|
||||||
|
p := httputil.NewSingleHostReverseProxy(url)
|
||||||
|
proxy := &proxyServer{
|
||||||
|
upstreamProxy: p,
|
||||||
|
status: notready,
|
||||||
|
failCount: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// start a goroutine to monitor upstream status via SSE
|
||||||
|
go func() {
|
||||||
|
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 0, // No timeout for SSE connection
|
||||||
|
}
|
||||||
|
|
||||||
|
waitDuration := 10 * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", eventsUrl, nil)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to create SSE request", "error", err)
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
time.Sleep(waitDuration)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to connect to SSE endpoint", "error", err)
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
|
||||||
|
_, _ = io.Copy(io.Discard, resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successfully connected to SSE endpoint
|
||||||
|
slog.Info("connected to SSE endpoint, upstream ready")
|
||||||
|
proxy.setStatus(ready)
|
||||||
|
proxy.resetFailures()
|
||||||
|
|
||||||
|
// Read from the SSE stream to detect disconnection
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|
||||||
|
// use a fairly large buffer to avoid scanner errors when reading large SSE events
|
||||||
|
buf := make([]byte, 0, 1024*1024*2)
|
||||||
|
scanner.Buffer(buf, 1024*1024*2)
|
||||||
|
events := 0
|
||||||
|
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
fmt.Print("Events: ")
|
||||||
|
}
|
||||||
|
for scanner.Scan() {
|
||||||
|
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
// Just read the events to keep connection alive
|
||||||
|
// We don't need to process the event data
|
||||||
|
events++
|
||||||
|
fmt.Printf("%d, ", events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
slog.Error("error reading from SSE stream", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection closed or error occurred
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
slog.Info("SSE connection closed, upstream not ready")
|
||||||
|
proxy.setStatus(notready)
|
||||||
|
proxy.incFail(1)
|
||||||
|
|
||||||
|
// Wait before reconnecting
|
||||||
|
time.Sleep(waitDuration)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "GET" && r.URL.Path == "/status" {
|
||||||
|
status := string(p.getStatus())
|
||||||
|
failCount := p.getFailures()
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
fmt.Fprintf(w, "status: %s\n", status)
|
||||||
|
fmt.Fprintf(w, "failures: %d\n", failCount)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.getStatus() == notready {
|
||||||
|
path := r.URL.Path
|
||||||
|
if strings.HasPrefix(path, "/api/events") {
|
||||||
|
slog.Debug("Skipping wake up", "req", path)
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
|
||||||
|
if err := sendMagicPacket(*flagMac); err != nil {
|
||||||
|
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For root or UI path requests, return loading page with status polling
|
||||||
|
// the web page will do the polling and redirect when ready
|
||||||
|
if path == "/" || strings.HasPrefix(path, "/ui/") {
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, loadingPageHTML)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout.Done():
|
||||||
|
slog.Info("timeout waiting for upstream to be ready")
|
||||||
|
http.Error(w, "timeout", http.StatusRequestTimeout)
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if p.getStatus() == ready {
|
||||||
|
ticker.Stop()
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.upstreamProxy.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) getStatus() upstreamStatus {
|
||||||
|
p.statusMutex.RLock()
|
||||||
|
defer p.statusMutex.RUnlock()
|
||||||
|
return p.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) setStatus(status upstreamStatus) {
|
||||||
|
p.statusMutex.Lock()
|
||||||
|
defer p.statusMutex.Unlock()
|
||||||
|
p.status = status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) incFail(num int) {
|
||||||
|
p.statusMutex.Lock()
|
||||||
|
defer p.statusMutex.Unlock()
|
||||||
|
p.failCount += num
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) getFailures() int {
|
||||||
|
p.statusMutex.RLock()
|
||||||
|
defer p.statusMutex.RUnlock()
|
||||||
|
return p.failCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyServer) resetFailures() {
|
||||||
|
p.statusMutex.Lock()
|
||||||
|
defer p.statusMutex.Unlock()
|
||||||
|
p.failCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMagicPacket(macAddr string) error {
|
||||||
|
hwAddr, err := net.ParseMAC(macAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hwAddr) != 6 {
|
||||||
|
return errors.New("invalid MAC address")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the magic packet.
|
||||||
|
packet := make([]byte, 102)
|
||||||
|
// Add 6 bytes of 0xFF.
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
packet[i] = 0xFF
|
||||||
|
}
|
||||||
|
// Repeat the MAC address 16 times.
|
||||||
|
for i := 1; i <= 16; i++ {
|
||||||
|
copy(packet[i*6:], hwAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet using UDP.
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.IPv4bcast,
|
||||||
|
Port: 9,
|
||||||
|
}
|
||||||
|
conn, err := net.DialUDP("udp", nil, &addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://json-schema.org/draft-07/schema#",
|
||||||
|
"$id": "llama-swap-config-schema.json",
|
||||||
|
"title": "llama-swap configuration",
|
||||||
|
"description": "Configuration file for llama-swap",
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"models"
|
||||||
|
],
|
||||||
|
"definitions": {
|
||||||
|
"macros": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 0,
|
||||||
|
"maxLength": 1024
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"propertyNames": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
"maxLength": 64,
|
||||||
|
"pattern": "^[a-zA-Z0-9_-]+$",
|
||||||
|
"not": {
|
||||||
|
"enum": [
|
||||||
|
"PORT",
|
||||||
|
"MODEL_ID"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"default": {},
|
||||||
|
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"properties": {
|
||||||
|
"healthCheckTimeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 15,
|
||||||
|
"default": 120,
|
||||||
|
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||||
|
},
|
||||||
|
"logLevel": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"debug",
|
||||||
|
"info",
|
||||||
|
"warn",
|
||||||
|
"error"
|
||||||
|
],
|
||||||
|
"default": "info",
|
||||||
|
"description": "Sets the logging value. Valid values: debug, info, warn, error."
|
||||||
|
},
|
||||||
|
"logTimeFormat": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"",
|
||||||
|
"ansic",
|
||||||
|
"unixdate",
|
||||||
|
"rubydate",
|
||||||
|
"rfc822",
|
||||||
|
"rfc822z",
|
||||||
|
"rfc850",
|
||||||
|
"rfc1123",
|
||||||
|
"rfc1123z",
|
||||||
|
"rfc3339",
|
||||||
|
"rfc3339nano",
|
||||||
|
"kitchen",
|
||||||
|
"stamp",
|
||||||
|
"stampmilli",
|
||||||
|
"stampmicro",
|
||||||
|
"stampnano"
|
||||||
|
],
|
||||||
|
"default": "",
|
||||||
|
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
|
||||||
|
},
|
||||||
|
"metricsMaxInMemory": {
|
||||||
|
"type": "integer",
|
||||||
|
"default": 1000,
|
||||||
|
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||||
|
},
|
||||||
|
"captureBuffer": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 5,
|
||||||
|
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||||
|
},
|
||||||
|
"startPort": {
|
||||||
|
"type": "integer",
|
||||||
|
"default": 5800,
|
||||||
|
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
|
||||||
|
},
|
||||||
|
"sendLoadingState": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
|
||||||
|
},
|
||||||
|
"includeAliasesInList": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
|
||||||
|
},
|
||||||
|
"macros": {
|
||||||
|
"$ref": "#/definitions/macros"
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"cmd"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"macros": {
|
||||||
|
"$ref": "#/definitions/macros"
|
||||||
|
},
|
||||||
|
"cmd": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
|
||||||
|
},
|
||||||
|
"cmdStop": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"maxLength": 128,
|
||||||
|
"description": "Display name for the model. Used in v1/models API response."
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"maxLength": 1024,
|
||||||
|
"description": "Description for the model. Used in v1/models API response."
|
||||||
|
},
|
||||||
|
"env": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
|
||||||
|
},
|
||||||
|
"default": [],
|
||||||
|
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
|
||||||
|
},
|
||||||
|
"proxy": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "http://localhost:${PORT}",
|
||||||
|
"format": "uri",
|
||||||
|
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
|
||||||
|
},
|
||||||
|
"aliases": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
},
|
||||||
|
"default": [],
|
||||||
|
"description": "Alternative model names for this configuration. Must be unique globally."
|
||||||
|
},
|
||||||
|
"checkEndpoint": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "/health",
|
||||||
|
"pattern": "^/.*$|^none$",
|
||||||
|
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
|
||||||
|
},
|
||||||
|
"ttl": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
||||||
|
},
|
||||||
|
"useModelName": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stripParams": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||||
|
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||||
|
},
|
||||||
|
"setParams": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of filter settings. Supports stripParams and setParams."
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
|
||||||
|
},
|
||||||
|
"concurrencyLimit": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"default": 0,
|
||||||
|
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
|
||||||
|
},
|
||||||
|
"sendLoadingState": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
|
||||||
|
},
|
||||||
|
"unlisted": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"groups": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"members"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"swap": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": true,
|
||||||
|
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||||
|
},
|
||||||
|
"exclusive": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": true,
|
||||||
|
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||||
|
},
|
||||||
|
"persistent": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||||
|
},
|
||||||
|
"members": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||||
|
},
|
||||||
|
"hooks": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"on_startup": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"preload": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"default": [],
|
||||||
|
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"description": "Actions to perform on startup. Only supported action is preload."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||||
|
},
|
||||||
|
"logToStdout": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"proxy",
|
||||||
|
"upstream",
|
||||||
|
"both",
|
||||||
|
"none"
|
||||||
|
],
|
||||||
|
"default": "proxy",
|
||||||
|
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||||
|
},
|
||||||
|
"apiKeys": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
},
|
||||||
|
"default": [],
|
||||||
|
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||||
|
},
|
||||||
|
"peers": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"proxy",
|
||||||
|
"models"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"proxy": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uri",
|
||||||
|
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||||
|
},
|
||||||
|
"apiKey": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
},
|
||||||
|
"description": "A list of models served by the peer."
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stripParams": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||||
|
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||||
|
},
|
||||||
|
"setParams": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {},
|
||||||
|
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"default": {},
|
||||||
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,88 +1,422 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# add this modeline for validation in vscode
|
||||||
# Default (and minimum): 15 seconds
|
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||||
healthCheckTimeout: 90
|
#
|
||||||
|
# llama-swap YAML configuration example
|
||||||
|
# -------------------------------------
|
||||||
|
#
|
||||||
|
# 💡 Tip - Use an LLM with this file!
|
||||||
|
# ====================================
|
||||||
|
# This example configuration is written to be LLM friendly. Try
|
||||||
|
# copying this file into an LLM and asking it to explain or generate
|
||||||
|
# sections for you.
|
||||||
|
# ====================================
|
||||||
|
|
||||||
# valid log levels: debug, info (default), warn, error
|
# Usage notes:
|
||||||
logLevel: debug
|
# - Below are all the available configuration options for llama-swap.
|
||||||
|
# - Settings noted as "required" must be in your configuration file
|
||||||
|
# - Settings noted as "optional" can be omitted
|
||||||
|
|
||||||
# creating a coding profile with models for code generation and general questions
|
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||||
groups:
|
# - optional, default: 120
|
||||||
coding:
|
# - minimum value is 15 seconds, anything less will be set to this value
|
||||||
swap: false
|
healthCheckTimeout: 500
|
||||||
members:
|
|
||||||
- "qwen"
|
|
||||||
- "llama"
|
|
||||||
|
|
||||||
|
# logLevel: sets the logging value
|
||||||
|
# - optional, default: info
|
||||||
|
# - Valid log levels: debug, info, warn, error
|
||||||
|
logLevel: info
|
||||||
|
|
||||||
|
# logTimeFormat: enables and sets the logging timestamp format
|
||||||
|
# - optional, default (disabled): ""
|
||||||
|
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||||
|
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||||
|
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||||
|
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||||
|
logTimeFormat: ""
|
||||||
|
|
||||||
|
# logToStdout: controls what is logged to stdout
|
||||||
|
# - optional, default: "proxy"
|
||||||
|
# - valid values:
|
||||||
|
# - "proxy": logs generated by llama-swap when swapping models,
|
||||||
|
# handling requests, etc.
|
||||||
|
# - "upstream": a copy of an upstream processes stdout logs
|
||||||
|
# - "both": both the proxy and upstream logs interleaved together
|
||||||
|
# - "none": no logs are ever written to stdout
|
||||||
|
logToStdout: "proxy"
|
||||||
|
|
||||||
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
|
# - optional, default: 1000
|
||||||
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||||
|
# - optional, default: 10
|
||||||
|
# - set to 0 to disable
|
||||||
|
captureBuffer: 15
|
||||||
|
|
||||||
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
|
# - optional, default: 5800
|
||||||
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
|
# - it is automatically incremented for every model that uses it
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
|
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||||
|
# field
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, a stream of loading messages will be sent to the client in the
|
||||||
|
# reasoning field so chat UIs can show that loading is in progress.
|
||||||
|
# - see #366 for more details
|
||||||
|
sendLoadingState: true
|
||||||
|
|
||||||
|
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, model aliases will be output to the API model listing duplicating
|
||||||
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# macros: a dictionary of string substitutions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros are reusable snippets
|
||||||
|
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||||
|
# - useful for reducing common configuration settings
|
||||||
|
# - macro names are strings and must be less than 64 characters
|
||||||
|
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
|
# - macro values can be numbers, bools, or strings
|
||||||
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
|
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||||
|
# - env macros are substituted first, before regular macros
|
||||||
|
# - if the env var is not set, config loading will fail with an error
|
||||||
|
macros:
|
||||||
|
# Example of a multi-line macro
|
||||||
|
"latest-llama": >
|
||||||
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
"default_ctx": 4096
|
||||||
|
|
||||||
|
# Example of macro-in-macro usage. macros can contain other macros
|
||||||
|
# but they must be previously declared.
|
||||||
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
|
# Example of environment variable macros
|
||||||
|
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||||
|
# - useful for paths, secrets, or machine-specific configuration
|
||||||
|
"models_dir": "${env.HOME}/models"
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
|
||||||
|
# use environment variable macros to keep secrets out of the config
|
||||||
|
- "${env.API_KEY_1}"
|
||||||
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
|
# models: a dictionary of model configurations
|
||||||
|
# - required
|
||||||
|
# - each key is the model's ID, used in API requests
|
||||||
|
# - model settings have default values that are used if they are not defined here
|
||||||
|
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||||
|
# - below are examples of the all the settings a model can have
|
||||||
models:
|
models:
|
||||||
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
models/llama-server-osx
|
# - optional, default: empty dictionary
|
||||||
--port ${PORT}
|
# - macros defined here override macros defined in the global macros section
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
# - model level macros follow the same rules as global macros
|
||||||
|
macros:
|
||||||
|
"default_ctx": 16384
|
||||||
|
"temp": 0.7
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# cmd: the command to run to start the inference server.
|
||||||
|
# - required
|
||||||
|
# - it is just a string, similar to what you would run on the CLI
|
||||||
|
# - using `|` allows for comments in the command, these will be parsed out
|
||||||
|
# - macros can be used within cmd
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
--ctx-size ${default_ctx}
|
||||||
|
--temperature ${temp}
|
||||||
|
|
||||||
|
# name: a display name for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
name: "llama 3.1 8B"
|
||||||
|
|
||||||
|
# description: a description for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
description: "A small but capable model used for quick testing"
|
||||||
|
|
||||||
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - each value is a single string
|
||||||
|
# - in the format: ENV_NAME=value
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||||
|
|
||||||
|
# proxy: the URL where llama-swap routes API requests
|
||||||
|
# - optional, default: http://localhost:${PORT}
|
||||||
|
# - if you used ${PORT} in cmd this can be omitted
|
||||||
|
# - if you use a custom port in cmd this *must* be set
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-4o-mini
|
- "gpt-4o-mini"
|
||||||
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# check this path for a HTTP 200 response for the server to be ready
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
checkEndpoint: /health
|
# - optional, default: /health
|
||||||
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
|
# - all requests wait until the endpoint is ready or fails
|
||||||
|
# - use "none" to skip endpoint health checking
|
||||||
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# unload model after 5 seconds
|
# ttl: automatically unload the model after ttl seconds
|
||||||
ttl: 5
|
# - optional, default: 0
|
||||||
|
# - ttl values must be a value greater than 0
|
||||||
|
# - a value of 0 disables automatic unloading of the model
|
||||||
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
# useModelName: override the model name that is sent to upstream server
|
||||||
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
# - optional, default: ""
|
||||||
aliases:
|
# - useful for when the upstream server expects a specific model name that
|
||||||
- gpt-3.5-turbo
|
# is different from the model's ID
|
||||||
|
useModelName: "qwen:qwq"
|
||||||
|
|
||||||
# Embedding example with Nomic
|
# filters: a dictionary of filter settings
|
||||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
# - optional, default: empty dictionary
|
||||||
"nomic":
|
# - same capabilities as peer filters (stripParams, setParams)
|
||||||
cmd: >
|
filters:
|
||||||
models/llama-server-osx --port ${PORT}
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
# - optional, default: ""
|
||||||
--ctx-size 8192
|
# - useful for server side enforcement of sampling parameters
|
||||||
--batch-size 8192
|
# - the `model` parameter can never be removed
|
||||||
--rope-scaling yarn
|
# - can be any JSON key in the request body
|
||||||
--rope-freq-scale 0.75
|
# - recommended to stick to sampling parameters
|
||||||
-ngl 99
|
stripParams: "temperature, top_p, top_k"
|
||||||
--embeddings
|
|
||||||
|
|
||||||
# Reranking example with bge-reranker
|
# setParams: a dictionary of parameters to set/override in requests
|
||||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
# - optional, default: empty dictionary
|
||||||
"bge-reranker":
|
# - useful for enforcing specific parameter values
|
||||||
cmd: >
|
# - protected params like "model" cannot be overridden
|
||||||
models/llama-server-osx --port ${PORT}
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
setParams:
|
||||||
--ctx-size 8192
|
# Example: enforce specific sampling parameters
|
||||||
--reranking
|
temperature: 0.7
|
||||||
|
top_p: 0.9
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
"dockertest":
|
# - optional, default: empty dictionary
|
||||||
cmd: >
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
docker run --name dockertest
|
# - metadata is only passed through in /v1/models responses
|
||||||
|
metadata:
|
||||||
|
# port will remain an integer
|
||||||
|
port: ${PORT}
|
||||||
|
|
||||||
|
# the ${temp} macro will remain a float
|
||||||
|
temperature: ${temp}
|
||||||
|
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||||
|
|
||||||
|
a_list:
|
||||||
|
- 1
|
||||||
|
- 1.23
|
||||||
|
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||||
|
|
||||||
|
an_obj:
|
||||||
|
a: "1"
|
||||||
|
b: 2
|
||||||
|
# objects can contain complex types with macro substitution
|
||||||
|
# becomes: c: [0.7, false, "model: llama"]
|
||||||
|
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||||
|
|
||||||
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
|
# - optional, default: 0
|
||||||
|
# - useful for limiting the number of active parallel requests a model can process
|
||||||
|
# - must be set per model
|
||||||
|
# - any number greater than 0 will override the internal default value of 10
|
||||||
|
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||||
|
# - recommended to be omitted and the default used
|
||||||
|
concurrencyLimit: 0
|
||||||
|
|
||||||
|
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||||
|
# - optional, default: undefined (use global setting)
|
||||||
|
sendLoadingState: false
|
||||||
|
|
||||||
|
# Unlisted model example:
|
||||||
|
"qwen-unlisted":
|
||||||
|
# unlisted: boolean, true or false
|
||||||
|
# - optional, default: false
|
||||||
|
# - unlisted models do not show up in /v1/models api requests
|
||||||
|
# - can be requested as normal through all apis
|
||||||
|
unlisted: true
|
||||||
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
|
# Docker example:
|
||||||
|
# container runtimes like Docker and Podman can be used reliably with
|
||||||
|
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
|
cmd: |
|
||||||
|
docker run --name ${MODEL_ID}
|
||||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
"simple":
|
# cmdStop: command to run to stop the model gracefully
|
||||||
# example of setting environment variables
|
# - optional, default: ""
|
||||||
env:
|
# - useful for stopping commands managed by another system
|
||||||
- CUDA_VISIBLE_DEVICES=0,1
|
# - the upstream's process id is available in the ${PID} macro
|
||||||
- env1=hello
|
#
|
||||||
cmd: build/simple-responder --port ${PORT}
|
# When empty, llama-swap has this default behaviour:
|
||||||
unlisted: true
|
# - on POSIX systems: a SIGTERM signal is sent
|
||||||
|
# - on Windows, calls taskkill to stop the process
|
||||||
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# use "none" to skip check. Caution this may cause some requests to fail
|
# groups: a dictionary of group settings
|
||||||
# until the upstream server is ready for traffic
|
# - optional, default: empty dictionary
|
||||||
checkEndpoint: none
|
# - provides advanced controls over model swapping behaviour
|
||||||
|
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||||
|
# - model IDs must be defined in the Models section
|
||||||
|
# - a model can only be a member of one group
|
||||||
|
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||||
|
# - see issue #109 for details
|
||||||
|
#
|
||||||
|
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||||
|
groups:
|
||||||
|
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||||
|
# to run a time across the whole llama-swap instance
|
||||||
|
"group1":
|
||||||
|
# swap: controls the model swapping behaviour in within the group
|
||||||
|
# - optional, default: true
|
||||||
|
# - true : only one model is allowed to run at a time
|
||||||
|
# - false: all models can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
# don't use these, just for testing if things are broken
|
# exclusive: controls how the group affects other groups
|
||||||
"broken":
|
# - optional, default: true
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
# - true: causes all other groups to unload when this group runs a model
|
||||||
proxy: http://127.0.0.1:8999
|
# - false: does not affect other groups
|
||||||
unlisted: true
|
exclusive: true
|
||||||
"broken_timeout":
|
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
# members references the models defined above
|
||||||
proxy: http://127.0.0.1:9000
|
# required
|
||||||
unlisted: true
|
members:
|
||||||
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - in group2 all models can run at the same time
|
||||||
|
# - when a different group is loaded it causes all running models in this group to unload
|
||||||
|
"group2":
|
||||||
|
swap: false
|
||||||
|
|
||||||
|
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||||
|
# - the models in group2 will be loaded but will not unload any other groups
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - a persistent group, prevents other groups from unloading it
|
||||||
|
"forever":
|
||||||
|
# persistent: prevents over groups from unloading the models in this group
|
||||||
|
# - optional, default: false
|
||||||
|
# - does not affect individual model behaviour
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# set swap/exclusive to false to prevent swapping inside the group
|
||||||
|
# and the unloading of other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
|
|
||||||
|
# hooks: a dictionary of event triggers and actions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported hook is on_startup
|
||||||
|
hooks:
|
||||||
|
# on_startup: a dictionary of actions to perform on startup
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported action is preload
|
||||||
|
on_startup:
|
||||||
|
# preload: a list of model ids to load on startup
|
||||||
|
# - optional, default: empty list
|
||||||
|
# - model names must match keys in the models sections
|
||||||
|
# - when preloading multiple models at once, define a group
|
||||||
|
# otherwise models will be loaded and swapped out
|
||||||
|
preload:
|
||||||
|
- "llama"
|
||||||
|
|
||||||
|
# peers: a dictionary of remote peers and models they provide
|
||||||
|
# - optional, default empty dictionary
|
||||||
|
# - peers can be another llama-swap
|
||||||
|
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||||
|
peers:
|
||||||
|
# keys is the peer'd ID
|
||||||
|
llama-swap-peer:
|
||||||
|
# proxy: a valid base URL to proxy requests to
|
||||||
|
# - required
|
||||||
|
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
# models: a list of models served by the peer
|
||||||
|
# - required
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
- embeddings/model_c
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
# apiKey: a string key to be injected into the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - if blank, no key will be added to the request
|
||||||
|
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||||
|
# - can be a string or a macro
|
||||||
|
apiKey: ${env.OPENROUTER_API_KEY}
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
- qwen/qwen3-235b-a22b-2507
|
||||||
|
- deepseek/deepseek-v3.2
|
||||||
|
- z-ai/glm-4.7
|
||||||
|
- moonshotai/kimi-k2-0905
|
||||||
|
- minimax/minimax-m2.1
|
||||||
|
# filters: a dictionary of filter settings for peer requests
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - same capabilities as model filters (stripParams, setParams)
|
||||||
|
filters:
|
||||||
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for removing parameters that the peer doesn't support
|
||||||
|
# - the `model` parameter can never be removed
|
||||||
|
stripParams: "temperature, top_p"
|
||||||
|
|
||||||
|
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - useful for injecting provider-specific settings like data retention policies
|
||||||
|
# - protected params like "model" cannot be overridden
|
||||||
|
# - values can be strings, numbers, booleans, arrays, or objects
|
||||||
|
setParams:
|
||||||
|
# Example: enforce zero-data-retention for OpenRouter
|
||||||
|
provider:
|
||||||
|
data_collection: "deny"
|
||||||
|
zdr: true
|
||||||
|
|||||||
@@ -1,55 +1,164 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
cd $(dirname "$0")
|
cd $(dirname "$0")
|
||||||
|
|
||||||
|
# use this to test locally, example:
|
||||||
|
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||||
|
# you need read:package scope on the token. Generate a personal access token with
|
||||||
|
# the scopes: gist, read:org, repo, write:packages
|
||||||
|
# then: gh auth login (and copy/paste the new token)
|
||||||
|
|
||||||
|
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||||
|
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||||
|
|
||||||
|
log_debug() {
|
||||||
|
if [ "$LOG_DEBUG" = "1" ]; then
|
||||||
|
echo "[DEBUG] $*"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo "[INFO] $*"
|
||||||
|
}
|
||||||
|
|
||||||
ARCH=$1
|
ARCH=$1
|
||||||
PUSH_IMAGES=${2:-false}
|
PUSH_IMAGES=${2:-false}
|
||||||
|
|
||||||
# List of allowed architectures
|
# List of allowed architectures
|
||||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu" "rocm")
|
||||||
|
|
||||||
# Check if ARCH is in the allowed list
|
# Check if ARCH is in the allowed list
|
||||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if GITHUB_TOKEN is set and not empty
|
# Check if GITHUB_TOKEN is set and not empty
|
||||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||||
|
# variable, this permits testing with forked llama.cpp repositories
|
||||||
|
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||||
|
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||||
|
|
||||||
|
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||||
|
# to enable easy container builds on forked repos
|
||||||
|
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||||
|
|
||||||
# the most recent llama-swap tag
|
# the most recent llama-swap tag
|
||||||
# have to strip out the 'v' due to .tar.gz file naming
|
# have to strip out the 'v' due to .tar.gz file naming
|
||||||
LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//')
|
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||||
|
|
||||||
|
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||||
|
# Handles pagination to search beyond the first 100 results
|
||||||
|
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||||
|
# Returns: the version number extracted from the tag
|
||||||
|
fetch_llama_tag() {
|
||||||
|
local tag_prefix=$1
|
||||||
|
local page=1
|
||||||
|
local per_page=100
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||||
|
|
||||||
|
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||||
|
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||||
|
|
||||||
|
# Check for API errors
|
||||||
|
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||||
|
local error_msg=$(echo "$response" | jq -r '.message')
|
||||||
|
log_info "GitHub API error: $error_msg"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if response is empty array (no more pages)
|
||||||
|
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||||
|
log_debug "No more pages (empty response)"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract matching tag from this page
|
||||||
|
local found_tag=$(echo "$response" | jq -r \
|
||||||
|
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||||
|
| sort -r | head -n1)
|
||||||
|
|
||||||
|
if [ -n "$found_tag" ]; then
|
||||||
|
log_debug "Found tag: $found_tag on page $page"
|
||||||
|
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
page=$((page + 1))
|
||||||
|
|
||||||
|
# Safety limit to prevent infinite loops
|
||||||
|
if [ $page -gt 50 ]; then
|
||||||
|
log_info "Reached pagination safety limit (50 pages)"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
if [ "$ARCH" == "cpu" ]; then
|
if [ "$ARCH" == "cpu" ]; then
|
||||||
# cpu only containers just use the latest available
|
LCPP_TAG=$(fetch_llama_tag "server")
|
||||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu"
|
BASE_TAG=server-${LCPP_TAG}
|
||||||
echo "Building ${CONTAINER_LATEST} $LS_VER"
|
|
||||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} .
|
|
||||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
|
||||||
docker push ${CONTAINER_LATEST}
|
|
||||||
fi
|
|
||||||
else
|
else
|
||||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
fi
|
||||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
|
||||||
|
|
||||||
# Abort if LCPP_TAG is empty.
|
SD_TAG=master-${ARCH}
|
||||||
if [[ -z "$LCPP_TAG" ]]; then
|
|
||||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
# Abort if LCPP_TAG is empty.
|
||||||
CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}"
|
if [[ -z "$LCPP_TAG" ]]; then
|
||||||
echo "Building ${CONTAINER_TAG} $LS_VER"
|
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} .
|
exit 1
|
||||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
else
|
||||||
docker push ${CONTAINER_TAG}
|
log_info "LCPP_TAG: $LCPP_TAG"
|
||||||
docker push ${CONTAINER_LATEST}
|
fi
|
||||||
fi
|
|
||||||
fi
|
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||||
|
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
for CONTAINER_TYPE in non-root root; do
|
||||||
|
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||||
|
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||||
|
USER_UID=0
|
||||||
|
USER_GID=0
|
||||||
|
USER_HOME=/root
|
||||||
|
|
||||||
|
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||||
|
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||||
|
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||||
|
USER_UID=10001
|
||||||
|
USER_GID=10001
|
||||||
|
USER_HOME=/app
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||||
|
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||||
|
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||||
|
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||||
|
|
||||||
|
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||||
|
case "$ARCH" in
|
||||||
|
"musa" | "vulkan")
|
||||||
|
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||||
|
docker build -f llama-swap-sd.Containerfile \
|
||||||
|
--build-arg BASE=${CONTAINER_TAG} \
|
||||||
|
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||||
|
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||||
|
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||||
|
docker push ${CONTAINER_TAG}
|
||||||
|
docker push ${CONTAINER_LATEST}
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
healthCheckTimeout: 300
|
healthCheckTimeout: 300
|
||||||
logRequests: true
|
logRequests: true
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"qwen2.5":
|
"qwen2.5":
|
||||||
@@ -14,4 +15,19 @@ models:
|
|||||||
cmd: >
|
cmd: >
|
||||||
/app/llama-server
|
/app/llama-server
|
||||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
--port 9999
|
--port 9999
|
||||||
|
|
||||||
|
z-image:
|
||||||
|
checkEndpoint: /
|
||||||
|
cmd: |
|
||||||
|
/app/sd-server
|
||||||
|
--listen-port 9999
|
||||||
|
--diffusion-fa
|
||||||
|
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||||
|
--vae /models/ae.safetensors
|
||||||
|
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||||
|
--offload-to-cpu
|
||||||
|
--cfg-scale 1.0
|
||||||
|
--height 512 --width 512
|
||||||
|
--steps 8
|
||||||
|
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||||
|
ARG SD_TAG=master-vulkan
|
||||||
|
ARG BASE=llama-swap:latest
|
||||||
|
|
||||||
|
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||||
|
FROM ${BASE}
|
||||||
|
|
||||||
|
ARG UID=10001
|
||||||
|
ARG GID=10001
|
||||||
|
|
||||||
|
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||||
@@ -1,16 +1,44 @@
|
|||||||
|
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||||
ARG BASE_TAG=server-cuda
|
ARG BASE_TAG=server-cuda
|
||||||
FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG}
|
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||||
|
|
||||||
# has to be after the FROM
|
# has to be after the FROM
|
||||||
ARG LS_VER=89
|
ARG LS_VER=170
|
||||||
|
ARG LS_REPO=mostlygeek/llama-swap
|
||||||
|
|
||||||
|
# Set default UID/GID arguments
|
||||||
|
ARG UID=10001
|
||||||
|
ARG GID=10001
|
||||||
|
ARG USER_HOME=/app
|
||||||
|
|
||||||
|
# Add user/group
|
||||||
|
ENV HOME=$USER_HOME
|
||||||
|
RUN if [ $UID -ne 0 ]; then \
|
||||||
|
if [ $GID -ne 0 ]; then \
|
||||||
|
groupadd --system --gid $GID app; \
|
||||||
|
fi; \
|
||||||
|
useradd --system --uid $UID --gid $GID \
|
||||||
|
--home $USER_HOME app; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Handle paths
|
||||||
|
RUN mkdir --parents $HOME /app
|
||||||
|
RUN chown --recursive $UID:$GID $HOME /app
|
||||||
|
|
||||||
|
# Switch user
|
||||||
|
USER $UID:$GID
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
RUN \
|
|
||||||
curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
|
||||||
tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \
|
|
||||||
rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz
|
|
||||||
|
|
||||||
COPY config.example.yaml /app/config.yaml
|
# Add /app to PATH
|
||||||
|
ENV PATH="/app:${PATH}"
|
||||||
|
|
||||||
|
RUN \
|
||||||
|
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||||
|
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||||
|
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||||
|
|
||||||
|
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||||
|
|
||||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
|
Before Width: | Height: | Size: 351 KiB After Width: | Height: | Size: 351 KiB |
|
After Width: | Height: | Size: 198 KiB |
@@ -0,0 +1,467 @@
|
|||||||
|
# config.yaml
|
||||||
|
|
||||||
|
llama-swap is designed to be very simple: one binary, one configuration file.
|
||||||
|
|
||||||
|
## minimal viable config
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
|
||||||
|
model2:
|
||||||
|
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
|
||||||
|
model3:
|
||||||
|
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
|
||||||
|
|
||||||
|
## Advanced control with `cmd`
|
||||||
|
|
||||||
|
llama-swap is also about customizability. You can use any CLI flag available:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: | # support for multi-line
|
||||||
|
llama-server --PORT ${PORT} -m /path/to/model.gguf
|
||||||
|
--ctx-size 8192
|
||||||
|
--jinja
|
||||||
|
--cache-type-k q8_0
|
||||||
|
--cache-type-v q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Support for any OpenAI API compatible server
|
||||||
|
|
||||||
|
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
"Q3-30B-CODER-VLLM":
|
||||||
|
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
|
||||||
|
# cmdStop provides a reliable way to stop containers
|
||||||
|
cmdStop: docker stop vllm-coder
|
||||||
|
cmd: |
|
||||||
|
docker run --init --rm --name vllm-coder
|
||||||
|
--runtime=nvidia --gpus '"device=2,3"'
|
||||||
|
--shm-size=16g
|
||||||
|
-v /mnt/nvme/vllm-cache:/root/.cache
|
||||||
|
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
|
||||||
|
vllm/vllm-openai:v0.10.0
|
||||||
|
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
|
||||||
|
--served-model-name "Q3-30B-CODER-VLLM"
|
||||||
|
--enable-expert-parallel
|
||||||
|
--swap-space 16
|
||||||
|
--max-num-seqs 512
|
||||||
|
--max-model-len 65536
|
||||||
|
--max-seq-len-to-capture 65536
|
||||||
|
--gpu-memory-utilization 0.9
|
||||||
|
--tensor-parallel-size 2
|
||||||
|
--trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many more features..
|
||||||
|
|
||||||
|
llama-swap supports many more features to customize how you want to manage your environment.
|
||||||
|
|
||||||
|
| Feature | Description |
|
||||||
|
| --------- | ---------------------------------------------- |
|
||||||
|
| `ttl` | automatic unloading of models after a timeout |
|
||||||
|
| `macros` | reusable snippets to use in configurations |
|
||||||
|
| `groups` | run multiple models at a time |
|
||||||
|
| `hooks` | event driven functionality |
|
||||||
|
| `env` | define environment variables per model |
|
||||||
|
| `aliases` | serve a model with different names |
|
||||||
|
| `filters` | modify requests before sending to the upstream |
|
||||||
|
| `...` | And many more tweaks |
|
||||||
|
|
||||||
|
## Full Configuration Example
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# add this modeline for validation in vscode
|
||||||
|
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||||
|
#
|
||||||
|
# llama-swap YAML configuration example
|
||||||
|
# -------------------------------------
|
||||||
|
#
|
||||||
|
# 💡 Tip - Use an LLM with this file!
|
||||||
|
# ====================================
|
||||||
|
# This example configuration is written to be LLM friendly. Try
|
||||||
|
# copying this file into an LLM and asking it to explain or generate
|
||||||
|
# sections for you.
|
||||||
|
# ====================================
|
||||||
|
|
||||||
|
# Usage notes:
|
||||||
|
# - Below are all the available configuration options for llama-swap.
|
||||||
|
# - Settings noted as "required" must be in your configuration file
|
||||||
|
# - Settings noted as "optional" can be omitted
|
||||||
|
|
||||||
|
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||||
|
# - optional, default: 120
|
||||||
|
# - minimum value is 15 seconds, anything less will be set to this value
|
||||||
|
healthCheckTimeout: 500
|
||||||
|
|
||||||
|
# logLevel: sets the logging value
|
||||||
|
# - optional, default: info
|
||||||
|
# - Valid log levels: debug, info, warn, error
|
||||||
|
logLevel: info
|
||||||
|
|
||||||
|
# logTimeFormat: enables and sets the logging timestamp format
|
||||||
|
# - optional, default (disabled): ""
|
||||||
|
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||||
|
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||||
|
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||||
|
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||||
|
logTimeFormat: ""
|
||||||
|
|
||||||
|
# logToStdout: controls what is logged to stdout
|
||||||
|
# - optional, default: "proxy"
|
||||||
|
# - valid values:
|
||||||
|
# - "proxy": logs generated by llama-swap when swapping models,
|
||||||
|
# handling requests, etc.
|
||||||
|
# - "upstream": a copy of an upstream processes stdout logs
|
||||||
|
# - "both": both the proxy and upstream logs interleaved together
|
||||||
|
# - "none": no logs are ever written to stdout
|
||||||
|
logToStdout: "proxy"
|
||||||
|
|
||||||
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
|
# - optional, default: 1000
|
||||||
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
|
# - optional, default: 5800
|
||||||
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
|
# - it is automatically incremented for every model that uses it
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
|
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||||
|
# field
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, a stream of loading messages will be sent to the client in the
|
||||||
|
# reasoning field so chat UIs can show that loading is in progress.
|
||||||
|
# - see #366 for more details
|
||||||
|
sendLoadingState: true
|
||||||
|
|
||||||
|
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||||
|
# - optional, default: false
|
||||||
|
# - when true, model aliases will be output to the API model listing duplicating
|
||||||
|
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||||
|
includeAliasesInList: false
|
||||||
|
|
||||||
|
# apiKeys: require an API key when making requests to inference endpoints
|
||||||
|
# - optional, default: []
|
||||||
|
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||||
|
# - each key is a non-empty string
|
||||||
|
apiKeys:
|
||||||
|
- "sk-hunter2"
|
||||||
|
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||||
|
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||||
|
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||||
|
|
||||||
|
# macros: a dictionary of string substitutions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros are reusable snippets
|
||||||
|
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||||
|
# - useful for reducing common configuration settings
|
||||||
|
# - macro names are strings and must be less than 64 characters
|
||||||
|
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
|
# - macro values can be numbers, bools, or strings
|
||||||
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
|
macros:
|
||||||
|
# Example of a multi-line macro
|
||||||
|
"latest-llama": >
|
||||||
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
"default_ctx": 4096
|
||||||
|
|
||||||
|
# Example of macro-in-macro usage. macros can contain other macros
|
||||||
|
# but they must be previously declared.
|
||||||
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
|
# models: a dictionary of model configurations
|
||||||
|
# - required
|
||||||
|
# - each key is the model's ID, used in API requests
|
||||||
|
# - model settings have default values that are used if they are not defined here
|
||||||
|
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||||
|
# - below are examples of the all the settings a model can have
|
||||||
|
models:
|
||||||
|
# keys are the model names used in API requests
|
||||||
|
"llama":
|
||||||
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros defined here override macros defined in the global macros section
|
||||||
|
# - model level macros follow the same rules as global macros
|
||||||
|
macros:
|
||||||
|
"default_ctx": 16384
|
||||||
|
"temp": 0.7
|
||||||
|
|
||||||
|
# cmd: the command to run to start the inference server.
|
||||||
|
# - required
|
||||||
|
# - it is just a string, similar to what you would run on the CLI
|
||||||
|
# - using `|` allows for comments in the command, these will be parsed out
|
||||||
|
# - macros can be used within cmd
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
--ctx-size ${default_ctx}
|
||||||
|
--temperature ${temp}
|
||||||
|
|
||||||
|
# name: a display name for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
name: "llama 3.1 8B"
|
||||||
|
|
||||||
|
# description: a description for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
description: "A small but capable model used for quick testing"
|
||||||
|
|
||||||
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - each value is a single string
|
||||||
|
# - in the format: ENV_NAME=value
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||||
|
|
||||||
|
# proxy: the URL where llama-swap routes API requests
|
||||||
|
# - optional, default: http://localhost:${PORT}
|
||||||
|
# - if you used ${PORT} in cmd this can be omitted
|
||||||
|
# - if you use a custom port in cmd this *must* be set
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
|
aliases:
|
||||||
|
- "gpt-4o-mini"
|
||||||
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
|
# - optional, default: /health
|
||||||
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
|
# - all requests wait until the endpoint is ready or fails
|
||||||
|
# - use "none" to skip endpoint health checking
|
||||||
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
|
# ttl: automatically unload the model after ttl seconds
|
||||||
|
# - optional, default: 0
|
||||||
|
# - ttl values must be a value greater than 0
|
||||||
|
# - a value of 0 disables automatic unloading of the model
|
||||||
|
ttl: 60
|
||||||
|
|
||||||
|
# useModelName: override the model name that is sent to upstream server
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for when the upstream server expects a specific model name that
|
||||||
|
# is different from the model's ID
|
||||||
|
useModelName: "qwen:qwq"
|
||||||
|
|
||||||
|
# filters: a dictionary of filter settings
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - only stripParams is currently supported
|
||||||
|
filters:
|
||||||
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for server side enforcement of sampling parameters
|
||||||
|
# - the `model` parameter can never be removed
|
||||||
|
# - can be any JSON key in the request body
|
||||||
|
# - recommended to stick to sampling parameters
|
||||||
|
stripParams: "temperature, top_p, top_k"
|
||||||
|
|
||||||
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
|
# - metadata is only passed through in /v1/models responses
|
||||||
|
metadata:
|
||||||
|
# port will remain an integer
|
||||||
|
port: ${PORT}
|
||||||
|
|
||||||
|
# the ${temp} macro will remain a float
|
||||||
|
temperature: ${temp}
|
||||||
|
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||||
|
|
||||||
|
a_list:
|
||||||
|
- 1
|
||||||
|
- 1.23
|
||||||
|
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||||
|
|
||||||
|
an_obj:
|
||||||
|
a: "1"
|
||||||
|
b: 2
|
||||||
|
# objects can contain complex types with macro substitution
|
||||||
|
# becomes: c: [0.7, false, "model: llama"]
|
||||||
|
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||||
|
|
||||||
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
|
# - optional, default: 0
|
||||||
|
# - useful for limiting the number of active parallel requests a model can process
|
||||||
|
# - must be set per model
|
||||||
|
# - any number greater than 0 will override the internal default value of 10
|
||||||
|
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||||
|
# - recommended to be omitted and the default used
|
||||||
|
concurrencyLimit: 0
|
||||||
|
|
||||||
|
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||||
|
# - optional, default: undefined (use global setting)
|
||||||
|
sendLoadingState: false
|
||||||
|
|
||||||
|
# Unlisted model example:
|
||||||
|
"qwen-unlisted":
|
||||||
|
# unlisted: boolean, true or false
|
||||||
|
# - optional, default: false
|
||||||
|
# - unlisted models do not show up in /v1/models api requests
|
||||||
|
# - can be requested as normal through all apis
|
||||||
|
unlisted: true
|
||||||
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
|
# Docker example:
|
||||||
|
# container runtimes like Docker and Podman can be used reliably with
|
||||||
|
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
|
cmd: |
|
||||||
|
docker run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
|
# cmdStop: command to run to stop the model gracefully
|
||||||
|
# - optional, default: ""
|
||||||
|
# - useful for stopping commands managed by another system
|
||||||
|
# - the upstream's process id is available in the ${PID} macro
|
||||||
|
#
|
||||||
|
# When empty, llama-swap has this default behaviour:
|
||||||
|
# - on POSIX systems: a SIGTERM signal is sent
|
||||||
|
# - on Windows, calls taskkill to stop the process
|
||||||
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
|
# groups: a dictionary of group settings
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - provides advanced controls over model swapping behaviour
|
||||||
|
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||||
|
# - model IDs must be defined in the Models section
|
||||||
|
# - a model can only be a member of one group
|
||||||
|
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||||
|
# - see issue #109 for details
|
||||||
|
#
|
||||||
|
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||||
|
groups:
|
||||||
|
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||||
|
# to run a time across the whole llama-swap instance
|
||||||
|
"group1":
|
||||||
|
# swap: controls the model swapping behaviour in within the group
|
||||||
|
# - optional, default: true
|
||||||
|
# - true : only one model is allowed to run at a time
|
||||||
|
# - false: all models can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
|
# exclusive: controls how the group affects other groups
|
||||||
|
# - optional, default: true
|
||||||
|
# - true: causes all other groups to unload when this group runs a model
|
||||||
|
# - false: does not affect other groups
|
||||||
|
exclusive: true
|
||||||
|
|
||||||
|
# members references the models defined above
|
||||||
|
# required
|
||||||
|
members:
|
||||||
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - in group2 all models can run at the same time
|
||||||
|
# - when a different group is loaded it causes all running models in this group to unload
|
||||||
|
"group2":
|
||||||
|
swap: false
|
||||||
|
|
||||||
|
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||||
|
# - the models in group2 will be loaded but will not unload any other groups
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - a persistent group, prevents other groups from unloading it
|
||||||
|
"forever":
|
||||||
|
# persistent: prevents over groups from unloading the models in this group
|
||||||
|
# - optional, default: false
|
||||||
|
# - does not affect individual model behaviour
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# set swap/exclusive to false to prevent swapping inside the group
|
||||||
|
# and the unloading of other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
|
|
||||||
|
# hooks: a dictionary of event triggers and actions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported hook is on_startup
|
||||||
|
hooks:
|
||||||
|
# on_startup: a dictionary of actions to perform on startup
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported action is preload
|
||||||
|
on_startup:
|
||||||
|
# preload: a list of model ids to load on startup
|
||||||
|
# - optional, default: empty list
|
||||||
|
# - model names must match keys in the models sections
|
||||||
|
# - when preloading multiple models at once, define a group
|
||||||
|
# otherwise models will be loaded and swapped out
|
||||||
|
preload:
|
||||||
|
- "llama"
|
||||||
|
|
||||||
|
# peers: a dictionary of remote peers and models they provide
|
||||||
|
# - optional, default empty dictionary
|
||||||
|
# - peers can be another llama-swap
|
||||||
|
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||||
|
peers:
|
||||||
|
# keys is the peer'd ID
|
||||||
|
llama-swap-peer:
|
||||||
|
# proxy: a valid base URL to proxy requests to
|
||||||
|
# - required
|
||||||
|
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
# models: a list of models served by the peer
|
||||||
|
# - required
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
- embeddings/model_c
|
||||||
|
openrouter:
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
# apiKey: a string key to be injected into the request
|
||||||
|
# - optional, default: ""
|
||||||
|
# - if blank, no key will be added to the request
|
||||||
|
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||||
|
apiKey: sk-your-openrouter-key
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
- qwen/qwen3-235b-a22b-2507
|
||||||
|
- deepseek/deepseek-v3.2
|
||||||
|
- z-ai/glm-4.7
|
||||||
|
- moonshotai/kimi-k2-0905
|
||||||
|
- minimax/minimax-m2.1
|
||||||
|
```
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
## Container Security
|
||||||
|
|
||||||
|
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||||
|
|
||||||
|
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||||
|
|
||||||
|
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||||
|
|
||||||
|
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||||
|
|
||||||
|
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default initializes a default in-process dispatcher
|
||||||
|
var Default = NewDispatcherConfig(25000)
|
||||||
|
|
||||||
|
// On subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work. This
|
||||||
|
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||||
|
func On[T Event](handler func(T)) context.CancelFunc {
|
||||||
|
return Subscribe(Default, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnType subscribes to an event with the specified event type. This functions
|
||||||
|
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||||
|
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
return SubscribeTo(Default, eventType, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit writes an event into the dispatcher. This functions same way as
|
||||||
|
// Publish() but uses the default dispatcher instead.
|
||||||
|
func Emit[T Event](ev T) {
|
||||||
|
Publish(Default, ev)
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||||
|
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||||
|
*/
|
||||||
|
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||||
|
unsub()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultPublish(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
var count int64
|
||||||
|
defer On(func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(4)
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(4), count)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event represents an event contract
|
||||||
|
type Event interface {
|
||||||
|
Type() uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// registry holds an immutable sorted array of event mappings
|
||||||
|
type registry struct {
|
||||||
|
keys []uint32 // Event types (sorted)
|
||||||
|
grps []any // Corresponding subscribers
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Dispatcher -------------------------------------
|
||||||
|
|
||||||
|
// Dispatcher represents an event dispatcher.
|
||||||
|
type Dispatcher struct {
|
||||||
|
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||||
|
done chan struct{} // Cancellation
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcher creates a new dispatcher of events.
|
||||||
|
func NewDispatcher() *Dispatcher {
|
||||||
|
return NewDispatcherConfig(50000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||||
|
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||||
|
d := &Dispatcher{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
maxQueue: maxQueue,
|
||||||
|
}
|
||||||
|
|
||||||
|
d.subs.Store(®istry{
|
||||||
|
keys: make([]uint32, 0, 16),
|
||||||
|
grps: make([]any, 0, 16),
|
||||||
|
})
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the dispatcher
|
||||||
|
func (d *Dispatcher) Close() error {
|
||||||
|
close(d.done)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClosed returns whether the dispatcher is closed or not
|
||||||
|
func (d *Dispatcher) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-d.done:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findGroup performs a lock-free binary search for the event type
|
||||||
|
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||||
|
reg := d.subs.Load()
|
||||||
|
keys := reg.keys
|
||||||
|
|
||||||
|
// Inlined binary search for better cache locality
|
||||||
|
left, right := 0, len(keys)
|
||||||
|
for left < right {
|
||||||
|
mid := left + (right-left)/2
|
||||||
|
if keys[mid] < eventType {
|
||||||
|
left = mid + 1
|
||||||
|
} else {
|
||||||
|
right = mid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if left < len(keys) && keys[left] == eventType {
|
||||||
|
return reg.grps[left]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work.
|
||||||
|
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||||
|
var event T
|
||||||
|
return SubscribeTo(broker, event.Type(), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeTo subscribes to an event with the specified event type.
|
||||||
|
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
if broker.isClosed() {
|
||||||
|
panic(errClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
broker.mu.Lock()
|
||||||
|
defer broker.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if group already exists
|
||||||
|
if existing := broker.findGroup(eventType); existing != nil {
|
||||||
|
grp := groupOf[T](eventType, existing)
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new group
|
||||||
|
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
|
||||||
|
// Copy-on-write: insert new entry in sorted position
|
||||||
|
old := broker.subs.Load()
|
||||||
|
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||||
|
return old.keys[i] >= eventType
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create new arrays with space for one more element
|
||||||
|
newKeys := make([]uint32, len(old.keys)+1)
|
||||||
|
newGrps := make([]any, len(old.grps)+1)
|
||||||
|
|
||||||
|
// Copy elements before insertion point
|
||||||
|
copy(newKeys[:idx], old.keys[:idx])
|
||||||
|
copy(newGrps[:idx], old.grps[:idx])
|
||||||
|
|
||||||
|
// Insert new element
|
||||||
|
newKeys[idx] = eventType
|
||||||
|
newGrps[idx] = grp
|
||||||
|
|
||||||
|
// Copy elements after insertion point
|
||||||
|
copy(newKeys[idx+1:], old.keys[idx:])
|
||||||
|
copy(newGrps[idx+1:], old.grps[idx:])
|
||||||
|
|
||||||
|
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||||
|
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||||
|
broker.subs.Store(newReg)
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish writes an event into the dispatcher
|
||||||
|
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||||
|
eventType := ev.Type()
|
||||||
|
if sub := broker.findGroup(eventType); sub != nil {
|
||||||
|
group := groupOf[T](eventType, sub)
|
||||||
|
group.Broadcast(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count counts the number of subscribers, this is for testing only.
|
||||||
|
func (d *Dispatcher) count(eventType uint32) int {
|
||||||
|
if group := d.findGroup(eventType); group != nil {
|
||||||
|
return group.(interface{ Count() int }).Count()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupOf casts the subscriber group to the specified generic type
|
||||||
|
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||||
|
if group, ok := subs.(*group[T]); ok {
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
panic(errConflict[T](eventType, subs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber -------------------------------------
|
||||||
|
|
||||||
|
// consumer represents a consumer with a message queue
|
||||||
|
type consumer[T Event] struct {
|
||||||
|
queue []T // Current work queue
|
||||||
|
stop bool // Stop signal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen listens to the event queue and processes events
|
||||||
|
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||||
|
pending := make([]T, 0, 128)
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.L.Lock()
|
||||||
|
for len(s.queue) == 0 {
|
||||||
|
switch {
|
||||||
|
case s.stop:
|
||||||
|
c.L.Unlock()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap buffers and reset the current queue
|
||||||
|
temp := s.queue
|
||||||
|
s.queue = pending[:0]
|
||||||
|
pending = temp
|
||||||
|
c.L.Unlock()
|
||||||
|
|
||||||
|
// Outside of the critical section, process the work
|
||||||
|
for _, event := range pending {
|
||||||
|
fn(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify potential publishers waiting due to backpressure
|
||||||
|
c.Broadcast()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber Group -------------------------------------
|
||||||
|
|
||||||
|
// group represents a consumer group
|
||||||
|
type group[T Event] struct {
|
||||||
|
cond *sync.Cond
|
||||||
|
subs []*consumer[T]
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
maxLen int // Current maximum queue length across all consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends an event to all consumers
|
||||||
|
func (s *group[T]) Broadcast(ev T) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Calculate current maximum queue length
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backpressure: wait if queues are full
|
||||||
|
for s.maxLen >= s.maxQueue {
|
||||||
|
s.cond.Wait()
|
||||||
|
|
||||||
|
// Recalculate after wakeup
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add event to all queues and track new maximum
|
||||||
|
newMax := 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
sub.queue = append(sub.queue, ev)
|
||||||
|
if len(sub.queue) > newMax {
|
||||||
|
newMax = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.maxLen = newMax
|
||||||
|
s.cond.Broadcast() // Wake consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a subscriber to the list
|
||||||
|
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||||
|
sub := &consumer[T]{
|
||||||
|
queue: make([]T, 0, 64),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the consumer to the list of active consumers
|
||||||
|
s.cond.L.Lock()
|
||||||
|
s.subs = append(s.subs, sub)
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Start listening
|
||||||
|
go sub.Listen(s.cond, handler)
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del removes a subscriber from the list
|
||||||
|
func (s *group[T]) Del(sub *consumer[T]) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Search and remove the subscriber
|
||||||
|
sub.stop = true
|
||||||
|
for i, v := range s.subs {
|
||||||
|
if v == sub {
|
||||||
|
copy(s.subs[i:], s.subs[i+1:])
|
||||||
|
s.subs = s.subs[:len(s.subs)-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Debugging -------------------------------------
|
||||||
|
|
||||||
|
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||||
|
|
||||||
|
// Count returns the number of subscribers in this group
|
||||||
|
func (s *group[T]) Count() int {
|
||||||
|
return len(s.subs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation of the type
|
||||||
|
func (s *group[T]) String() string {
|
||||||
|
typ := reflect.TypeOf(s).String()
|
||||||
|
idx := strings.LastIndex(typ, "/")
|
||||||
|
typ = typ[idx+1 : len(typ)-1]
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
// errConflict returns a conflict message
|
||||||
|
func errConflict[T any](eventType uint32, existing any) string {
|
||||||
|
var want T
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||||
|
want, existing, eventType,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublish(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe, must be received in order
|
||||||
|
var count int64
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(3)
|
||||||
|
Publish(d, MyEvent1{Number: 1})
|
||||||
|
Publish(d, MyEvent1{Number: 2})
|
||||||
|
Publish(d, MyEvent1{Number: 3})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(3), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsubscribe(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Nothing
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||||
|
unsubscribe()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrent(t *testing.T) {
|
||||||
|
const max = 1000000
|
||||||
|
var count int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
if current := atomic.AddInt64(&count, 1); current == max {
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Asynchronously publish
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < max; i++ {
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Subscriber that does nothing
|
||||||
|
})()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, max, int(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscribeDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseDispatcher(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||||
|
|
||||||
|
assert.NoError(t, d.Close())
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix(t *testing.T) {
|
||||||
|
const amount = 1000
|
||||||
|
for _, subs := range []int{1, 10, 100} {
|
||||||
|
for _, topics := range []int{1, 10} {
|
||||||
|
expected := subs * topics * amount
|
||||||
|
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||||
|
var count atomic.Int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(expected)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
for i := 0; i < subs; i++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||||
|
count.Add(1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < amount; n++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
go Publish(d, MyEvent3{ID: id})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, expected, int(count.Load()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||||
|
// This test specifically targets the race condition that occurs when multiple
|
||||||
|
// goroutines try to subscribe to different event types simultaneously.
|
||||||
|
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||||
|
|
||||||
|
const numGoroutines = 100
|
||||||
|
const numEventTypes = 50
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var receivedCount int64
|
||||||
|
var subscribedTypes sync.Map // Thread-safe map
|
||||||
|
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
// Start multiple goroutines that subscribe to different event types concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Each goroutine subscribes to a unique event type
|
||||||
|
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||||
|
|
||||||
|
// Subscribe to the event type
|
||||||
|
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&receivedCount, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Record that this type was subscribed
|
||||||
|
subscribedTypes.Store(eventType, true)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all subscriptions to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Count the number of unique event types subscribed
|
||||||
|
expectedTypes := 0
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
expectedTypes++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Small delay to ensure all subscriptions are fully processed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Publish events to each subscribed type
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
eventType := key.(uint32)
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for all events to be processed
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify that we received at least the expected number of events
|
||||||
|
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||||
|
received := atomic.LoadInt64(&receivedCount)
|
||||||
|
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||||
|
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||||
|
|
||||||
|
// Verify that we have the expected number of unique event types
|
||||||
|
assert.Equal(t, numEventTypes, expectedTypes,
|
||||||
|
"Should have exactly %d unique event types", numEventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to the same event type
|
||||||
|
t.Run("SameEventType", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var handlerCount int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&handlerCount, 1)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify all handlers were registered by publishing an event
|
||||||
|
atomic.StoreInt64(&handlerCount, 0)
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||||
|
"Not all handlers were registered due to race condition")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to different event types
|
||||||
|
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
receivedEvents := make(map[uint32]*int64)
|
||||||
|
|
||||||
|
// Create multiple event types and subscribe concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
eventType := uint32(100 + i)
|
||||||
|
counter := new(int64)
|
||||||
|
receivedEvents[eventType] = counter
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(et uint32, cnt *int64) {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(cnt, 1)
|
||||||
|
})
|
||||||
|
}(eventType, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Publish events to all types
|
||||||
|
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all event types received their events
|
||||||
|
for eventType, counter := range receivedEvents {
|
||||||
|
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||||
|
"Event type %d did not receive its event", eventType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackpressure(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
d.maxQueue = 10
|
||||||
|
|
||||||
|
var processedCount int64
|
||||||
|
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&processedCount, 1)
|
||||||
|
})
|
||||||
|
defer unsub()
|
||||||
|
|
||||||
|
const eventsToPublish = 1000
|
||||||
|
for i := 0; i < eventsToPublish; i++ {
|
||||||
|
Publish(d, MyEvent3{ID: 0x200})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all events were eventually processed
|
||||||
|
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||||
|
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||||
|
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Test Events -------------------------------------
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeEvent1 = 0x1
|
||||||
|
TypeEvent2 = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
type MyEvent1 struct {
|
||||||
|
Number int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||||
|
|
||||||
|
type MyEvent2 struct {
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||||
|
|
||||||
|
type MyEvent3 struct {
|
||||||
|
ID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
module github.com/mostlygeek/llama-swap
|
module github.com/mostlygeek/llama-swap
|
||||||
|
|
||||||
go 1.23.0
|
go 1.25.4
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/billziss-gh/golib v0.2.0
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
@@ -37,9 +37,9 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.36.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
golang.org/x/net v0.38.0 // indirect
|
golang.org/x/net v0.47.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||||
|
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
@@ -30,8 +32,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
@@ -80,16 +80,16 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
|||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -26,7 +28,9 @@ var (
|
|||||||
func main() {
|
func main() {
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
configPath := flag.String("config", "config.yaml", "config file name")
|
configPath := flag.String("config", "config.yaml", "config file name")
|
||||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
listenStr := flag.String("listen", "", "listen ip/port")
|
||||||
|
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||||
|
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||||
showVersion := flag.Bool("version", false, "show version of build")
|
showVersion := flag.Bool("version", false, "show version of build")
|
||||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||||
|
|
||||||
@@ -37,13 +41,13 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := proxy.LoadConfig(*configPath)
|
conf, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error loading config: %v\n", err)
|
fmt.Printf("Error loading config: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Profiles) > 0 {
|
if len(conf.Profiles) > 0 {
|
||||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,137 +57,163 @@ func main() {
|
|||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
// Validate TLS flags.
|
||||||
|
var useTLS = (*certFile != "" && *keyFile != "")
|
||||||
|
if (*certFile != "" && *keyFile == "") ||
|
||||||
|
(*certFile == "" && *keyFile != "") {
|
||||||
|
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default ports.
|
||||||
|
if *listenStr == "" {
|
||||||
|
defaultPort := ":8080"
|
||||||
|
if useTLS {
|
||||||
|
defaultPort = ":8443"
|
||||||
|
}
|
||||||
|
listenStr = &defaultPort
|
||||||
|
}
|
||||||
|
|
||||||
// Setup channels for server management
|
// Setup channels for server management
|
||||||
reloadChan := make(chan *proxy.ProxyManager)
|
|
||||||
exitChan := make(chan struct{})
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Create server with initial handler
|
// Create server with initial handler
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: *listenStr,
|
Addr: *listenStr,
|
||||||
Handler: proxyManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start server
|
// Support for watching config and reloading when it changes
|
||||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
reloadProxyManager := func() {
|
||||||
go func() {
|
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
conf, err = config.LoadConfig(*configPath)
|
||||||
fmt.Printf("Fatal server error: %v\n", err)
|
if err != nil {
|
||||||
close(exitChan)
|
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Handle config reloads and signals
|
|
||||||
go func() {
|
|
||||||
currentManager := proxyManager
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case newManager := <-reloadChan:
|
|
||||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
|
||||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
|
||||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
|
||||||
// Now do a full shutdown to clear the process map
|
|
||||||
currentManager.Shutdown()
|
|
||||||
currentManager = newManager
|
|
||||||
srv.Handler = newManager
|
|
||||||
log.Println("Server handler updated with new config")
|
|
||||||
case sig := <-sigChan:
|
|
||||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
currentManager.Shutdown()
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
|
||||||
fmt.Printf("Server shutdown error: %v\n", err)
|
|
||||||
}
|
|
||||||
close(exitChan)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start file watcher if requested
|
fmt.Println("Configuration Changed")
|
||||||
if *watchConfig {
|
currentPM.Shutdown()
|
||||||
absConfigPath, err := filepath.Abs(*configPath)
|
newPM := proxy.New(conf)
|
||||||
if err != nil {
|
newPM.SetVersion(date, commit, version)
|
||||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
srv.Handler = newPM
|
||||||
|
fmt.Println("Configuration Reloaded")
|
||||||
|
|
||||||
|
// wait a few seconds and tell any UI to reload
|
||||||
|
time.AfterFunc(3*time.Second, func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateEnd,
|
||||||
|
})
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
conf, err = config.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
newPM := proxy.New(conf)
|
||||||
|
newPM.SetVersion(date, commit, version)
|
||||||
|
srv.Handler = newPM
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load the initial proxy manager
|
||||||
|
reloadProxyManager()
|
||||||
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
|
if *watchConfig {
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
fmt.Println("Watching Configuration for changes")
|
||||||
|
go func() {
|
||||||
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Dir(absConfigPath)
|
||||||
|
err = watcher.Add(configDir)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer watcher.Close()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case changeEvent := <-watcher.Events:
|
||||||
|
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||||
|
// the change for k8s configmap
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case err := <-watcher.Errors:
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown on signal
|
||||||
|
go func() {
|
||||||
|
sig := <-sigChan
|
||||||
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
pm.Shutdown()
|
||||||
|
} else {
|
||||||
|
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Printf("Server shutdown error: %v\n", err)
|
||||||
|
}
|
||||||
|
close(exitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
if useTLS {
|
||||||
|
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||||
|
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
}
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("Fatal server error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Wait for exit signal
|
// Wait for exit signal
|
||||||
<-exitChan
|
<-exitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
func debounce(interval time.Duration, f func()) func() {
|
||||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
var timer *time.Timer
|
||||||
watcher, err := fsnotify.NewWatcher()
|
return func() {
|
||||||
if err != nil {
|
if timer != nil {
|
||||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
timer.Stop()
|
||||||
return
|
|
||||||
}
|
|
||||||
defer watcher.Close()
|
|
||||||
|
|
||||||
err = watcher.Add(configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Watching config file for changes: %s", configPath)
|
|
||||||
|
|
||||||
var debounceTimer *time.Timer
|
|
||||||
debounceDuration := 2 * time.Second
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case event, ok := <-watcher.Events:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// We only care about writes to the specific config file
|
|
||||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
|
||||||
// Reset or start the debounce timer
|
|
||||||
if debounceTimer != nil {
|
|
||||||
debounceTimer.Stop()
|
|
||||||
}
|
|
||||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
|
||||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
|
||||||
|
|
||||||
// Try up to 3 times with exponential backoff
|
|
||||||
var newConfig proxy.Config
|
|
||||||
var err error
|
|
||||||
for retries := 0; retries < 3; retries++ {
|
|
||||||
// Load new configuration
|
|
||||||
newConfig, err = proxy.LoadConfig(configPath)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
|
||||||
if retries < 2 {
|
|
||||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to load new config after retries: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new ProxyManager with new config
|
|
||||||
newPM := proxy.New(newConfig)
|
|
||||||
reloadChan <- newPM
|
|
||||||
log.Println("Config reloaded successfully")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case err, ok := <-watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
log.Println("File watcher error channel closed.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("File watcher error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
timer = time.AfterFunc(interval, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 51 KiB |
@@ -0,0 +1 @@
|
|||||||
|
ui_dist/*
|
||||||
@@ -1,247 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/google/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DEFAULT_GROUP_ID = "(default)"
|
|
||||||
|
|
||||||
type ModelConfig struct {
|
|
||||||
Cmd string `yaml:"cmd"`
|
|
||||||
Proxy string `yaml:"proxy"`
|
|
||||||
Aliases []string `yaml:"aliases"`
|
|
||||||
Env []string `yaml:"env"`
|
|
||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
|
||||||
UnloadAfter int `yaml:"ttl"`
|
|
||||||
Unlisted bool `yaml:"unlisted"`
|
|
||||||
UseModelName string `yaml:"useModelName"`
|
|
||||||
|
|
||||||
// Limit concurrency of HTTP requests to process
|
|
||||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|
||||||
return SanitizeCommand(m.Cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupConfig struct {
|
|
||||||
Swap bool `yaml:"swap"`
|
|
||||||
Exclusive bool `yaml:"exclusive"`
|
|
||||||
Persistent bool `yaml:"persistent"`
|
|
||||||
Members []string `yaml:"members"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
||||||
type rawGroupConfig GroupConfig
|
|
||||||
defaults := rawGroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Persistent: false,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*c = GroupConfig(defaults)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
|
||||||
LogRequests bool `yaml:"logRequests"`
|
|
||||||
LogLevel string `yaml:"logLevel"`
|
|
||||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
|
||||||
|
|
||||||
// map aliases to actual model IDs
|
|
||||||
aliases map[string]string
|
|
||||||
|
|
||||||
// automatic port assignments
|
|
||||||
StartPort int `yaml:"startPort"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
|
||||||
if _, found := c.Models[search]; found {
|
|
||||||
return search, true
|
|
||||||
} else if name, found := c.aliases[search]; found {
|
|
||||||
return name, found
|
|
||||||
} else {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|
||||||
if realName, found := c.RealModelName(modelName); !found {
|
|
||||||
return ModelConfig{}, "", false
|
|
||||||
} else {
|
|
||||||
return c.Models[realName], realName, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfig(path string) (Config, error) {
|
|
||||||
file, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
return LoadConfigFromReader(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var config Config
|
|
||||||
err = yaml.Unmarshal(data, &config)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default port ranges
|
|
||||||
if config.StartPort == 0 {
|
|
||||||
// default to 5800
|
|
||||||
config.StartPort = 5800
|
|
||||||
} else if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate over the models and replace any ${PORT} with the next available port
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
|
||||||
|
|
||||||
// iterate over the sorted models
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
if modelConfig.Proxy == "" {
|
|
||||||
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
|
|
||||||
} else {
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
}
|
|
||||||
nextPort++
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
} else if modelConfig.Proxy == "" {
|
|
||||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
// check that members are all unique in the groups
|
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
// Check for duplicates within this group
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
// Check if member is used in another group
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
|
||||||
|
|
||||||
if config.Groups == nil {
|
|
||||||
config.Groups = make(map[string]GroupConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultGroup := GroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
// if groups is empty, create a default group and put
|
|
||||||
// all models into it
|
|
||||||
if len(config.Groups) == 0 {
|
|
||||||
for modelName := range config.Models {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// iterate over existing group members and add non-grouped models into the default group
|
|
||||||
for modelName, _ := range config.Models {
|
|
||||||
foundModel := false
|
|
||||||
found:
|
|
||||||
// search for the model in existing groups
|
|
||||||
for _, groupConfig := range config.Groups {
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if member == modelName {
|
|
||||||
foundModel = true
|
|
||||||
break found
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundModel {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
|
||||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
// Remove trailing backslashes
|
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
args, err := shlex.Split(cmdStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,734 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
const (
|
||||||
|
LogToStdoutProxy = "proxy"
|
||||||
|
LogToStdoutUpstream = "upstream"
|
||||||
|
LogToStdoutBoth = "both"
|
||||||
|
LogToStdoutNone = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||||
|
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("macros must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||||
|
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var val any
|
||||||
|
if err := valueNode.Decode(&val); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||||
|
}
|
||||||
|
|
||||||
|
*ml = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a macro value by name
|
||||||
|
func (ml MacroList) Get(name string) (any, bool) {
|
||||||
|
for _, entry := range ml {
|
||||||
|
if entry.Name == name {
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||||
|
func (ml MacroList) ToMap() map[string]any {
|
||||||
|
result := make(map[string]any, len(ml))
|
||||||
|
for _, entry := range ml {
|
||||||
|
result[entry.Name] = entry.Value
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupConfig struct {
|
||||||
|
Swap bool `yaml:"swap"`
|
||||||
|
Exclusive bool `yaml:"exclusive"`
|
||||||
|
Persistent bool `yaml:"persistent"`
|
||||||
|
Members []string `yaml:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// set default values for GroupConfig
|
||||||
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawGroupConfig GroupConfig
|
||||||
|
defaults := rawGroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Persistent: false,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = GroupConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HooksConfig struct {
|
||||||
|
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HookOnStartup struct {
|
||||||
|
Preload []string `yaml:"preload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
LogRequests bool `yaml:"logRequests"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||||
|
LogToStdout string `yaml:"logToStdout"`
|
||||||
|
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||||
|
CaptureBuffer int `yaml:"captureBuffer"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// map aliases to actual model IDs
|
||||||
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
|
|
||||||
|
// hooks, see: #209
|
||||||
|
Hooks HooksConfig `yaml:"hooks"`
|
||||||
|
|
||||||
|
// send loading state in reasoning
|
||||||
|
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||||
|
|
||||||
|
// present aliases to /v1/models OpenAI API listing
|
||||||
|
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||||
|
|
||||||
|
// support API keys, see issue #433, #50, #251
|
||||||
|
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||||
|
|
||||||
|
// support remote peers, see issue #433, #296
|
||||||
|
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
|
if _, found := c.Models[search]; found {
|
||||||
|
return search, true
|
||||||
|
} else if name, found := c.aliases[search]; found {
|
||||||
|
return name, found
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||||
|
if realName, found := c.RealModelName(modelName); !found {
|
||||||
|
return ModelConfig{}, "", false
|
||||||
|
} else {
|
||||||
|
return c.Models[realName], realName, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(path string) (Config, error) {
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
yamlStr := string(data)
|
||||||
|
|
||||||
|
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||||
|
// This is safe because env values are simple strings without YAML formatting
|
||||||
|
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal into full Config with defaults
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
}
|
||||||
|
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.LogToStdout {
|
||||||
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate global macros
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs for consistent port assignment
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds)
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
|
||||||
|
// Strip comments from command fields
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// Validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (override globals with same name)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute remaining macros in model fields (LIFO order)
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (type-preserving)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PORT macro - only allocate if cmd uses it
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort {
|
||||||
|
if !cmdHasPort && proxyHasPort {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // replaced at runtime
|
||||||
|
}
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.SendLoadingState == nil {
|
||||||
|
v := config.SendLoadingState
|
||||||
|
modelConfig.SendLoadingState = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
|
// Validate group members
|
||||||
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate API keys (env macros already substituted at string level)
|
||||||
|
for i, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||||
|
}
|
||||||
|
config.RequiredAPIKeys[i] = apikey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process peers with global macro substitution
|
||||||
|
for peerName, peerConfig := range config.Peers {
|
||||||
|
// Substitute global macros (LIFO order)
|
||||||
|
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||||
|
entry := config.Macros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||||
|
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in setParams (type-preserving)
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||||
|
}
|
||||||
|
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Peers[peerName] = peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
|
if config.Groups == nil {
|
||||||
|
config.Groups = make(map[string]GroupConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGroup := GroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
// if groups is empty, create a default group and put
|
||||||
|
// all models into it
|
||||||
|
if len(config.Groups) == 0 {
|
||||||
|
for modelName := range config.Models {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over existing group members and add non-grouped models into the default group
|
||||||
|
for modelName := range config.Models {
|
||||||
|
foundModel := false
|
||||||
|
found:
|
||||||
|
// search for the model in existing groups
|
||||||
|
for _, groupConfig := range config.Groups {
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if member == modelName {
|
||||||
|
foundModel = true
|
||||||
|
break found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundModel {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||||
|
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if len(v) >= 1024 {
|
||||||
|
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||||
|
}
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||||
|
func validateNestedForUnknownMacros(value any, context string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||||
|
}
|
||||||
|
// Check for unsubstituted env macros
|
||||||
|
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range envMatches {
|
||||||
|
varName := match[1]
|
||||||
|
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||||
|
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||||
|
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||||
|
// (which strips comments) and only checking the comment-free version for macros.
|
||||||
|
func substituteEnvMacros(s string) (string, error) {
|
||||||
|
// Unmarshal and remarshal to strip YAML comments
|
||||||
|
var raw any
|
||||||
|
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||||
|
// If YAML is invalid, fall back to scanning the original string
|
||||||
|
// so the user gets the env var error rather than a confusing YAML parse error
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
clean, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return substituteEnvMacrosInString(s, string(clean))
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||||
|
// them in target. This separation allows scanning comment-free YAML while
|
||||||
|
// substituting in the original string.
|
||||||
|
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||||
|
result := target
|
||||||
|
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
fullMatch := match[0] // ${env.VAR_NAME}
|
||||||
|
varName := match[1] // VAR_NAME
|
||||||
|
|
||||||
|
value, exists := os.LookupEnv(varName)
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the value for safe YAML substitution
|
||||||
|
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings.ReplaceAll(result, fullMatch, value)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||||
|
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||||
|
// for compatibility with double-quoted YAML strings.
|
||||||
|
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||||
|
// Reject values that would break YAML structure regardless of quoting context
|
||||||
|
if strings.ContainsAny(value, "\n\r\x00") {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||||
|
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||||
|
// In double-quoted contexts, they are interpreted correctly.
|
||||||
|
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||||
|
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,253 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// Test a command with spaces and newlines
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
-a "double quotes" \
|
||||||
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
# comment 1
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the default values are automatically set for global, model and group configurations
|
||||||
|
// after loading the configuration
|
||||||
|
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
assert.Equal(t, "", config.LogTimeFormat)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadPosix(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
hooks:
|
||||||
|
on_startup:
|
||||||
|
preload: ["model1", "model2"]
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
name: "Model 1"
|
||||||
|
description: "This is model 1"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelLoadingState := false
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: MacroList{
|
||||||
|
{"svr-path", "path/to/server"},
|
||||||
|
},
|
||||||
|
Hooks: HooksConfig{
|
||||||
|
OnStartup: HookOnStartup{
|
||||||
|
Preload: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SendLoadingState: false,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
Name: "Model 1",
|
||||||
|
Description: "This is model 1",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// does not support single quoted strings like in config_posix_test.go
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
|
||||||
|
-a "double quotes" \
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
assert.Equal(t, "", config.LogTimeFormat)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadWindows(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelLoadingState := false
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: MacroList{
|
||||||
|
{"svr-path", "path/to/server"},
|
||||||
|
},
|
||||||
|
SendLoadingState: false,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
SendLoadingState: &modelLoadingState,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||||
|
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||||
|
var ProtectedParams = []string{"model"}
|
||||||
|
|
||||||
|
// Filters contains filter settings for modifying request parameters
|
||||||
|
// Used by both models and peers
|
||||||
|
type Filters struct {
|
||||||
|
// StripParams is a comma-separated list of parameters to remove from requests
|
||||||
|
// The "model" parameter can never be removed
|
||||||
|
StripParams string `yaml:"stripParams"`
|
||||||
|
|
||||||
|
// SetParams is a dictionary of parameters to set/override in requests
|
||||||
|
// Protected params (like "model") cannot be set
|
||||||
|
SetParams map[string]any `yaml:"setParams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||||
|
// with duplicates, empty strings, and protected params removed
|
||||||
|
func (f Filters) SanitizedStripParams() []string {
|
||||||
|
if f.StripParams == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := strings.Split(f.StripParams, ",")
|
||||||
|
cleaned := make([]string, 0, len(params))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, param := range params {
|
||||||
|
trimmed := strings.TrimSpace(param)
|
||||||
|
// Skip protected params, empty strings, and duplicates
|
||||||
|
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = true
|
||||||
|
cleaned = append(cleaned, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cleaned) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Sort(cleaned)
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||||
|
// and keys sorted for consistent iteration order
|
||||||
|
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||||
|
if len(f.SetParams) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]any, len(f.SetParams))
|
||||||
|
keys := make([]string, 0, len(f.SetParams))
|
||||||
|
|
||||||
|
for key, value := range f.SetParams {
|
||||||
|
// Skip protected params
|
||||||
|
if slices.Contains(ProtectedParams, key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[key] = value
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort keys for consistent ordering
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, keys
|
||||||
|
}
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
stripParams string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
stripParams: "",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single param",
|
||||||
|
stripParams: "temperature",
|
||||||
|
want: []string{"temperature"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple params",
|
||||||
|
stripParams: "temperature, top_p, top_k",
|
||||||
|
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model param filtered",
|
||||||
|
stripParams: "model, temperature, top_p",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only model param",
|
||||||
|
stripParams: "model",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicates removed",
|
||||||
|
stripParams: "temperature, top_p, temperature",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra whitespace",
|
||||||
|
stripParams: " temperature , top_p ",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty values filtered",
|
||||||
|
stripParams: "temperature,,top_p,",
|
||||||
|
want: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := Filters{StripParams: tt.stripParams}
|
||||||
|
got := f.SanitizedStripParams()
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setParams map[string]any
|
||||||
|
wantParams map[string]any
|
||||||
|
wantKeys []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty setParams",
|
||||||
|
setParams: nil,
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
setParams: map[string]any{},
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal params",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature", "top_p"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "protected model param filtered",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
wantKeys: []string{"temperature"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only protected param",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"model": "should-be-filtered",
|
||||||
|
},
|
||||||
|
wantParams: nil,
|
||||||
|
wantKeys: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex nested values",
|
||||||
|
setParams: map[string]any{
|
||||||
|
"provider": map[string]any{
|
||||||
|
"data_collection": "deny",
|
||||||
|
"allow_fallbacks": false,
|
||||||
|
},
|
||||||
|
"transforms": []string{"middle-out"},
|
||||||
|
},
|
||||||
|
wantParams: map[string]any{
|
||||||
|
"provider": map[string]any{
|
||||||
|
"data_collection": "deny",
|
||||||
|
"allow_fallbacks": false,
|
||||||
|
},
|
||||||
|
"transforms": []string{"middle-out"},
|
||||||
|
},
|
||||||
|
wantKeys: []string{"provider", "transforms"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := Filters{SetParams: tt.setParams}
|
||||||
|
gotParams, gotKeys := f.SanitizedSetParams()
|
||||||
|
|
||||||
|
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||||
|
for i, key := range gotKeys {
|
||||||
|
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantParams == nil {
|
||||||
|
assert.Nil(t, gotParams, "expected nil params")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||||
|
for key, wantValue := range tt.wantParams {
|
||||||
|
gotValue, exists := gotParams[key]
|
||||||
|
assert.True(t, exists, "missing key: %s", key)
|
||||||
|
// Simple comparison for basic types
|
||||||
|
switch v := wantValue.(type) {
|
||||||
|
case string, int, float64, bool:
|
||||||
|
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedParams(t *testing.T) {
|
||||||
|
// Verify that "model" is protected
|
||||||
|
assert.Contains(t, ProtectedParams, "model")
|
||||||
|
}
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test macro-in-macro basic substitution
|
||||||
|
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${B}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test LIFO substitution order with 3+ macro levels
|
||||||
|
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: load ${full}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MODEL_ID in global macro used by model
|
||||||
|
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: ${podman-llama} -m model.gguf
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model macro overrides global macro in substitution
|
||||||
|
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: echo ${msg}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test self-reference detection error
|
||||||
|
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${recursive}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "recursive")
|
||||||
|
assert.Contains(t, err.Error(), "self-reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test undefined macro reference error
|
||||||
|
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${A}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||||
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelConfig struct {
|
||||||
|
Cmd string `yaml:"cmd"`
|
||||||
|
CmdStop string `yaml:"cmdStop"`
|
||||||
|
Proxy string `yaml:"proxy"`
|
||||||
|
Aliases []string `yaml:"aliases"`
|
||||||
|
Env []string `yaml:"env"`
|
||||||
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
|
UnloadAfter int `yaml:"ttl"`
|
||||||
|
Unlisted bool `yaml:"unlisted"`
|
||||||
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// #179 for /v1/models
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
|
|
||||||
|
// Model filters see issue #174
|
||||||
|
Filters ModelFilters `yaml:"filters"`
|
||||||
|
|
||||||
|
// Macros: see #264
|
||||||
|
// Model level macros take precedence over the global macros
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// Metadata: see #264
|
||||||
|
// Arbitrary metadata that can be exposed through the API
|
||||||
|
Metadata map[string]any `yaml:"metadata"`
|
||||||
|
|
||||||
|
// override global setting
|
||||||
|
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelConfig ModelConfig
|
||||||
|
defaults := rawModelConfig{
|
||||||
|
Cmd: "",
|
||||||
|
CmdStop: "",
|
||||||
|
Proxy: "http://localhost:${PORT}",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
UnloadAfter: 0,
|
||||||
|
Unlisted: false,
|
||||||
|
UseModelName: "",
|
||||||
|
ConcurrencyLimit: 0,
|
||||||
|
Name: "",
|
||||||
|
Description: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
return SanitizeCommand(m.Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||||
|
// See issue #174
|
||||||
|
type ModelFilters struct {
|
||||||
|
Filters `yaml:",inline"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelFilters ModelFilters
|
||||||
|
defaults := rawModelFilters{}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to unmarshal with the old field name for backwards compatibility
|
||||||
|
if defaults.StripParams == "" {
|
||||||
|
var legacy struct {
|
||||||
|
StripParams string `yaml:"strip_params"`
|
||||||
|
}
|
||||||
|
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||||
|
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||||
|
}
|
||||||
|
defaults.StripParams = legacy.StripParams
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelFilters(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||||
|
// Returns ([]string, error) to match existing API
|
||||||
|
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||||
|
return f.Filters.SanitizedStripParams(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
|
config := &ModelConfig{
|
||||||
|
Cmd: `python model1.py \
|
||||||
|
--arg1 value1 \
|
||||||
|
--arg2 value2`,
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := config.SanitizedCommand()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFilters(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
default_strip: "temperature, top_p"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
# macros inserted and list is cleaned of duplicates and empty strings
|
||||||
|
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
# check for strip_params (legacy field name) compatibility
|
||||||
|
legacy:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for modelId, modelConfig := range config.Models {
|
||||||
|
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||||
|
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||||
|
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
// model has been removed
|
||||||
|
// empty strings have been removed
|
||||||
|
// duplicates have been removed
|
||||||
|
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelSendLoadingState(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
sendLoadingState: true
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
sendLoadingState: false
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, config.SendLoadingState)
|
||||||
|
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
|
||||||
|
assert.False(t, *config.Models["model1"].SendLoadingState)
|
||||||
|
}
|
||||||
|
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
|
||||||
|
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
stripParams: "top_k"
|
||||||
|
setParams:
|
||||||
|
temperature: 0.7
|
||||||
|
top_p: 0.9
|
||||||
|
stop:
|
||||||
|
- "<|end|>"
|
||||||
|
- "<|stop|>"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
modelConfig := config.Models["model1"]
|
||||||
|
|
||||||
|
// Check stripParams
|
||||||
|
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||||
|
|
||||||
|
// Check setParams
|
||||||
|
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||||
|
assert.NotNil(t, setParams)
|
||||||
|
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||||
|
assert.Equal(t, 0.7, setParams["temperature"])
|
||||||
|
assert.Equal(t, 0.9, setParams["top_p"])
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerDictionaryConfig map[string]PeerConfig
|
||||||
|
type PeerConfig struct {
|
||||||
|
Proxy string `yaml:"proxy"`
|
||||||
|
ProxyURL *url.URL `yaml:"-"`
|
||||||
|
ApiKey string `yaml:"apiKey"`
|
||||||
|
Models []string `yaml:"models"`
|
||||||
|
Filters Filters `yaml:"filters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawPeerConfig PeerConfig
|
||||||
|
defaults := rawPeerConfig{
|
||||||
|
Proxy: "",
|
||||||
|
ApiKey: "",
|
||||||
|
Models: []string{},
|
||||||
|
Filters: Filters{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate proxy is not empty
|
||||||
|
if defaults.Proxy == "" {
|
||||||
|
return fmt.Errorf("proxy is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate proxy is a valid URL and store the parsed value
|
||||||
|
parsedURL, err := url.Parse(defaults.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||||
|
}
|
||||||
|
defaults.ProxyURL = parsedURL
|
||||||
|
|
||||||
|
// Validate models is not empty
|
||||||
|
if len(defaults.Models) == 0 {
|
||||||
|
return fmt.Errorf("peer models can not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = PeerConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
yaml string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://192.168.1.23
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
- model_b
|
||||||
|
`,
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config with apiKey",
|
||||||
|
yaml: `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test-key
|
||||||
|
models:
|
||||||
|
- meta-llama/llama-3.1-8b-instruct
|
||||||
|
`,
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing proxy",
|
||||||
|
yaml: `
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "proxy is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty proxy",
|
||||||
|
yaml: `
|
||||||
|
proxy: ""
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "proxy is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid proxy URL",
|
||||||
|
yaml: `
|
||||||
|
proxy: "://invalid"
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`,
|
||||||
|
wantErr: "invalid peer proxy URL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing models",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`,
|
||||||
|
wantErr: "peer models can not be empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty models",
|
||||||
|
yaml: `
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
models: []
|
||||||
|
`,
|
||||||
|
wantErr: "peer models can not be empty",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||||
|
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||||
|
} else if !contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: http://192.168.1.23:8080/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL == nil {
|
||||||
|
t.Fatal("ProxyURL should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||||
|
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Scheme != "http" {
|
||||||
|
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL.Path != "/api" {
|
||||||
|
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func searchSubstring(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
filters:
|
||||||
|
setParams:
|
||||||
|
temperature: 0.7
|
||||||
|
provider:
|
||||||
|
data_collection: deny
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Filters.SetParams == nil {
|
||||||
|
t.Fatal("Filters.SetParams should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||||
|
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("provider should be a map")
|
||||||
|
}
|
||||||
|
if provider["data_collection"] != "deny" {
|
||||||
|
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||||
|
yamlData := `
|
||||||
|
proxy: https://openrouter.ai/api
|
||||||
|
apiKey: sk-test
|
||||||
|
models:
|
||||||
|
- model_a
|
||||||
|
filters:
|
||||||
|
stripParams: "temperature, top_p"
|
||||||
|
setParams:
|
||||||
|
max_tokens: 1000
|
||||||
|
`
|
||||||
|
var config PeerConfig
|
||||||
|
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check stripParams
|
||||||
|
stripParams := config.Filters.SanitizedStripParams()
|
||||||
|
if len(stripParams) != 2 {
|
||||||
|
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||||
|
}
|
||||||
|
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||||
|
t.Errorf("unexpected strip params: %v", stripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check setParams
|
||||||
|
if config.Filters.SetParams == nil {
|
||||||
|
t.Fatal("Filters.SetParams should not be nil")
|
||||||
|
}
|
||||||
|
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||||
|
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
|
||||||
// Create a temporary YAML file for testing
|
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
aliases:
|
|
||||||
- "m1"
|
|
||||||
- "model-one"
|
|
||||||
env:
|
|
||||||
- "VAR1=value1"
|
|
||||||
- "VAR2=value2"
|
|
||||||
checkEndpoint: "/health"
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
aliases:
|
|
||||||
- "m2"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model3:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
aliases:
|
|
||||||
- "mthree"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model4:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8082"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
|
||||||
profiles:
|
|
||||||
test:
|
|
||||||
- model1
|
|
||||||
- model2
|
|
||||||
groups:
|
|
||||||
group1:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
forever:
|
|
||||||
exclusive: false
|
|
||||||
persistent: true
|
|
||||||
members:
|
|
||||||
- "model4"
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write temporary file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the config and verify
|
|
||||||
config, err := LoadConfig(tempFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := Config{
|
|
||||||
StartPort: 5800,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8080",
|
|
||||||
Aliases: []string{"m1", "model-one"},
|
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
},
|
|
||||||
"model2": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"m2"},
|
|
||||||
Env: nil,
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
"model3": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"mthree"},
|
|
||||||
Env: nil,
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
"model4": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8082",
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1", "model2"},
|
|
||||||
},
|
|
||||||
aliases: map[string]string{
|
|
||||||
"m1": "model1",
|
|
||||||
"model-one": "model1",
|
|
||||||
"m2": "model2",
|
|
||||||
"mthree": "model3",
|
|
||||||
},
|
|
||||||
Groups: map[string]GroupConfig{
|
|
||||||
DEFAULT_GROUP_ID: {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Members: []string{"model1", "model3"},
|
|
||||||
},
|
|
||||||
"group1": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Members: []string{"model2"},
|
|
||||||
},
|
|
||||||
"forever": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Persistent: true,
|
|
||||||
Members: []string{"model4"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, expected, config)
|
|
||||||
|
|
||||||
realname, found := config.RealModelName("m1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", realname)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model3:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
|
||||||
groups:
|
|
||||||
group1:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
group2:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
`
|
|
||||||
// Load the config and verify
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
|
|
||||||
// a Contains as order of the map is not guaranteed
|
|
||||||
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
- m2
|
|
||||||
`
|
|
||||||
// Load the config and verify
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
|
|
||||||
// this is a contains because it could be `model1` or `model2` depending on the order
|
|
||||||
// go decided on the order of the map
|
|
||||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
|
||||||
config := &ModelConfig{
|
|
||||||
Cmd: `python model1.py \
|
|
||||||
--arg1 value1 \
|
|
||||||
--arg2 value2`,
|
|
||||||
}
|
|
||||||
|
|
||||||
args, err := config.SanitizedCommand()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_FindConfig(t *testing.T) {
|
|
||||||
|
|
||||||
// TODO?
|
|
||||||
// make make this shared between the different tests
|
|
||||||
config := &Config{
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": {
|
|
||||||
Cmd: "python model1.py",
|
|
||||||
Proxy: "http://localhost:8080",
|
|
||||||
Aliases: []string{"m1", "model-one"},
|
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
},
|
|
||||||
"model2": {
|
|
||||||
Cmd: "python model2.py",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"m2", "model-two"},
|
|
||||||
Env: []string{"VAR3=value3", "VAR4=value4"},
|
|
||||||
CheckEndpoint: "/status",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
HealthCheckTimeout: 10,
|
|
||||||
aliases: map[string]string{
|
|
||||||
"m1": "model1",
|
|
||||||
"model-one": "model1",
|
|
||||||
"m2": "model2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test finding a model by its name
|
|
||||||
modelConfig, modelId, found := config.FindConfig("model1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", modelId)
|
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
|
||||||
|
|
||||||
// Test finding a model by its alias
|
|
||||||
modelConfig, modelId, found = config.FindConfig("m1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", modelId)
|
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
|
||||||
|
|
||||||
// Test finding a model that does not exist
|
|
||||||
modelConfig, modelId, found = config.FindConfig("model3")
|
|
||||||
assert.False(t, found)
|
|
||||||
assert.Equal(t, "", modelId)
|
|
||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
|
||||||
args, err := SanitizeCommand(`python model1.py \
|
|
||||||
-a "double quotes" \
|
|
||||||
--arg2 'single quotes'
|
|
||||||
-s
|
|
||||||
--arg3 123 \
|
|
||||||
--arg4 '"string in string"'
|
|
||||||
-c "'single quoted'"
|
|
||||||
`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{
|
|
||||||
"python", "model1.py",
|
|
||||||
"-a", "double quotes",
|
|
||||||
"--arg2", "single quotes",
|
|
||||||
"-s",
|
|
||||||
"--arg3", "123",
|
|
||||||
"--arg4", `"string in string"`,
|
|
||||||
"-c", `'single quoted'`,
|
|
||||||
}, args)
|
|
||||||
|
|
||||||
// Test an empty command
|
|
||||||
args, err = SanitizeCommand("")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
|
||||||
|
|
||||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
|
||||||
content := ``
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
})
|
|
||||||
t.Run("User specific port ranges", func(t *testing.T) {
|
|
||||||
content := `startPort: 1000`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 1000, config.StartPort)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Invalid start port", func(t *testing.T) {
|
|
||||||
content := `startPort: abcd`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("start port must be greater than 1", func(t *testing.T) {
|
|
||||||
content := `startPort: -99`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Automatic port assignments", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
startPort: 5800
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
model2:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
proxy: "http://172.11.22.33:${PORT}"
|
|
||||||
model3:
|
|
||||||
cmd: svr --port 1999
|
|
||||||
proxy: "http://1.2.3.4:1999"
|
|
||||||
`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
|
||||||
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
|
||||||
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
|
||||||
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port 111
|
|
||||||
`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||||
|
type DiscardWriter struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Header() http.Header {
|
||||||
|
if w.header == nil {
|
||||||
|
w.header = make(http.Header)
|
||||||
|
}
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) WriteHeader(code int) {
|
||||||
|
w.status = code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Satisfy the http.Flusher interface for streaming responses
|
||||||
|
func (w *DiscardWriter) Flush() {}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
// package level registry of the different event types
|
||||||
|
|
||||||
|
const ProcessStateChangeEventID = 0x01
|
||||||
|
const ChatCompletionStatsEventID = 0x02
|
||||||
|
const ConfigFileChangedEventID = 0x03
|
||||||
|
const LogDataEventID = 0x04
|
||||||
|
const TokenMetricsEventID = 0x05
|
||||||
|
const ModelPreloadedEventID = 0x06
|
||||||
|
|
||||||
|
type ProcessStateChangeEvent struct {
|
||||||
|
ProcessName string
|
||||||
|
NewState ProcessState
|
||||||
|
OldState ProcessState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||||
|
return ProcessStateChangeEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStats struct {
|
||||||
|
TokensGenerated int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ChatCompletionStats) Type() uint32 {
|
||||||
|
return ChatCompletionStatsEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReloadingState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReloadingStateStart ReloadingState = iota
|
||||||
|
ReloadingStateEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigFileChangedEvent struct {
|
||||||
|
ReloadingState ReloadingState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||||
|
return ConfigFileChangedEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogDataEvent struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e LogDataEvent) Type() uint32 {
|
||||||
|
return LogDataEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelPreloadedEvent struct {
|
||||||
|
ModelName string
|
||||||
|
Success bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ModelPreloadedEvent) Type() uint32 {
|
||||||
|
return ModelPreloadedEventID
|
||||||
|
}
|
||||||
@@ -9,12 +9,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nextTestPort int = 12000
|
nextTestPort int = 12000
|
||||||
portMutex sync.Mutex
|
portMutex sync.Mutex
|
||||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
simpleResponderPath = getSimpleResponderPath()
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
// Check if the binary exists
|
||||||
@@ -45,7 +48,12 @@ func TestMain(m *testing.M) {
|
|||||||
func getSimpleResponderPath() string {
|
func getSimpleResponderPath() string {
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
goarch := runtime.GOARCH
|
goarch := runtime.GOARCH
|
||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
|
||||||
|
if goos == "windows" {
|
||||||
|
return filepath.Join("..", "build", "simple-responder.exe")
|
||||||
|
} else {
|
||||||
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestPort() int {
|
func getTestPort() int {
|
||||||
@@ -58,17 +66,25 @@ func getTestPort() int {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||||
binaryPath := getSimpleResponderPath()
|
// Convert path to forward slashes for cross-platform compatibility
|
||||||
|
// Windows handles forward slashes in paths correctly
|
||||||
|
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||||
|
|
||||||
// Create a process configuration
|
// Create a YAML string with just the values we want to set
|
||||||
return ModelConfig{
|
yamlStr := fmt.Sprintf(`
|
||||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
cmd: '%s --port %d --silent --respond %s'
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
proxy: "http://127.0.0.1:%d"
|
||||||
CheckEndpoint: "/health",
|
`, cmdPath, port, expectedMessage, port)
|
||||||
|
|
||||||
|
var cfg config.ModelConfig
|
||||||
|
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
}
|
}
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 15 KiB |
@@ -1,14 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>llama-swap</title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<h1>llama-swap</h1>
|
|
||||||
<p>
|
|
||||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
|
||||||
</p>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Logs</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
margin: 0;
|
|
||||||
height: 100vh;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
font-family: "Courier New", Courier, monospace;
|
|
||||||
}
|
|
||||||
.log-container {
|
|
||||||
display: flex;
|
|
||||||
flex: 1;
|
|
||||||
gap: 0.5em;
|
|
||||||
margin: 0.5em;
|
|
||||||
min-height: 0;
|
|
||||||
}
|
|
||||||
.log-column {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
flex: 1;
|
|
||||||
min-width: 0;
|
|
||||||
transition: flex 0.3s ease;
|
|
||||||
}
|
|
||||||
.log-column.minimized {
|
|
||||||
flex: 0.1;
|
|
||||||
max-width: 50px;
|
|
||||||
border: 1px solid #777;
|
|
||||||
color: green;
|
|
||||||
}
|
|
||||||
.log-controls {
|
|
||||||
display: grid;
|
|
||||||
grid-template-columns: 1fr auto;
|
|
||||||
gap: 0.5em;
|
|
||||||
margin-bottom: 0.5em;
|
|
||||||
}
|
|
||||||
.log-controls input {
|
|
||||||
width: 100%;
|
|
||||||
padding: 4px;
|
|
||||||
}
|
|
||||||
.log-controls input:focus {
|
|
||||||
outline: none;
|
|
||||||
}
|
|
||||||
.log-stream {
|
|
||||||
flex: 1;
|
|
||||||
padding: 1em;
|
|
||||||
background: #f4f4f4;
|
|
||||||
overflow-y: auto;
|
|
||||||
white-space: pre-wrap;
|
|
||||||
word-wrap: break-word;
|
|
||||||
min-height: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.regex-error {
|
|
||||||
background-color: #ff0000 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Make headers clickable and show pointer cursor */
|
|
||||||
h2 {
|
|
||||||
cursor: pointer;
|
|
||||||
user-select: none;
|
|
||||||
margin: 0 0 0.5em 0;
|
|
||||||
padding: 0.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
h2:hover {
|
|
||||||
background-color: rgba(0, 0, 0, 0.05);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Dark mode styles */
|
|
||||||
@media (prefers-color-scheme: dark) {
|
|
||||||
body {
|
|
||||||
background-color: #333;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-stream {
|
|
||||||
background: #444;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-controls input {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-controls button {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
|
|
||||||
h2:hover {
|
|
||||||
background-color: rgba(255, 255, 255, 0.1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Hide content when minimized */
|
|
||||||
.log-column.minimized .log-controls,
|
|
||||||
.log-column.minimized .log-stream {
|
|
||||||
display: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-column.minimized h2 {
|
|
||||||
writing-mode: vertical-rl;
|
|
||||||
text-orientation: mixed;
|
|
||||||
transform: rotate(180deg);
|
|
||||||
white-space: nowrap;
|
|
||||||
margin: auto;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="log-container">
|
|
||||||
<div class="log-column">
|
|
||||||
<h2>Proxy Logs</h2>
|
|
||||||
<div class="log-controls">
|
|
||||||
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
|
|
||||||
<button id="proxy-clear-button">clear</button>
|
|
||||||
</div>
|
|
||||||
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
|
|
||||||
</div>
|
|
||||||
<div class="log-column minimized">
|
|
||||||
<h2>Upstream Logs</h2>
|
|
||||||
<div class="log-controls">
|
|
||||||
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
|
|
||||||
<button id="upstream-clear-button">clear</button>
|
|
||||||
</div>
|
|
||||||
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<script>
|
|
||||||
class LogStream {
|
|
||||||
constructor(streamElement, filterInput, clearButton, endpoint) {
|
|
||||||
this.streamElement = streamElement;
|
|
||||||
this.filterInput = filterInput;
|
|
||||||
this.clearButton = clearButton;
|
|
||||||
this.endpoint = endpoint;
|
|
||||||
this.logData = "";
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.eventSource = null;
|
|
||||||
|
|
||||||
this.initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
initialize() {
|
|
||||||
this.filterInput.addEventListener('input', () => this.updateFilter());
|
|
||||||
this.clearButton.addEventListener('click', () => {
|
|
||||||
this.filterInput.value = "";
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.render();
|
|
||||||
});
|
|
||||||
this.setupEventSource();
|
|
||||||
}
|
|
||||||
|
|
||||||
setupEventSource() {
|
|
||||||
if (typeof(EventSource) === "undefined") {
|
|
||||||
this.logData = "SSE Not supported by this browser.";
|
|
||||||
this.render();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const connect = () => {
|
|
||||||
this.eventSource = new EventSource(this.endpoint);
|
|
||||||
|
|
||||||
this.eventSource.onmessage = (event) => {
|
|
||||||
this.logData += event.data;
|
|
||||||
this.logData = this.logData.slice(-1024 * 100);
|
|
||||||
this.render();
|
|
||||||
};
|
|
||||||
|
|
||||||
this.eventSource.onerror = (err) => {
|
|
||||||
// Close the current connection
|
|
||||||
this.eventSource.close();
|
|
||||||
|
|
||||||
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
|
|
||||||
this.render();
|
|
||||||
|
|
||||||
// Attempt to reconnect after 5 seconds
|
|
||||||
setTimeout(() => {
|
|
||||||
this.logData += "Attempting to reconnect...\n";
|
|
||||||
this.render();
|
|
||||||
connect();
|
|
||||||
}, 5000);
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initial connection
|
|
||||||
connect();
|
|
||||||
}
|
|
||||||
|
|
||||||
render() {
|
|
||||||
let content = this.logData;
|
|
||||||
|
|
||||||
if (this.regexFilter) {
|
|
||||||
const lines = content.split('\n');
|
|
||||||
const filteredLines = lines.filter(line => this.regexFilter.test(line));
|
|
||||||
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
|
|
||||||
}
|
|
||||||
|
|
||||||
this.streamElement.textContent = content;
|
|
||||||
this.streamElement.scrollTop = this.streamElement.scrollHeight;
|
|
||||||
}
|
|
||||||
|
|
||||||
updateFilter() {
|
|
||||||
const pattern = this.filterInput.value.trim();
|
|
||||||
this.filterInput.classList.remove('regex-error');
|
|
||||||
|
|
||||||
if (!pattern) {
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.render();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.regexFilter = new RegExp(pattern);
|
|
||||||
} catch (e) {
|
|
||||||
console.error("Invalid regex pattern:", e);
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.filterInput.classList.add('regex-error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize both log streams
|
|
||||||
document.addEventListener('DOMContentLoaded', () => {
|
|
||||||
new LogStream(
|
|
||||||
document.getElementById('proxy-log-stream'),
|
|
||||||
document.getElementById('proxy-filter-input'),
|
|
||||||
document.getElementById('proxy-clear-button'),
|
|
||||||
"/logs/streamSSE/proxy"
|
|
||||||
);
|
|
||||||
|
|
||||||
new LogStream(
|
|
||||||
document.getElementById('upstream-log-stream'),
|
|
||||||
document.getElementById('upstream-filter-input'),
|
|
||||||
document.getElementById('upstream-clear-button'),
|
|
||||||
"/logs/streamSSE/upstream"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Initialize clickable headers
|
|
||||||
document.querySelectorAll('h2').forEach(header => {
|
|
||||||
header.addEventListener('click', () => {
|
|
||||||
const column = header.closest('.log-column');
|
|
||||||
column.classList.toggle('minimized');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import "embed"
|
|
||||||
|
|
||||||
//go:embed html
|
|
||||||
var htmlFiles embed.FS
|
|
||||||
|
|
||||||
func getHTMLFile(path string) ([]byte, error) {
|
|
||||||
return htmlFiles.ReadFile("html/" + path)
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,95 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||||
|
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||||
|
type circularBuffer struct {
|
||||||
|
data []byte // pre-allocated capacity
|
||||||
|
head int // next write position
|
||||||
|
size int // current number of bytes stored (0 to cap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCircularBuffer(capacity int) *circularBuffer {
|
||||||
|
return &circularBuffer{
|
||||||
|
data: make([]byte, capacity),
|
||||||
|
head: 0,
|
||||||
|
size: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||||
|
// Data is copied into the internal buffer (not stored by reference).
|
||||||
|
func (cb *circularBuffer) Write(p []byte) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cap := len(cb.data)
|
||||||
|
|
||||||
|
// If input is larger than capacity, only keep the last cap bytes
|
||||||
|
if len(p) >= cap {
|
||||||
|
copy(cb.data, p[len(p)-cap:])
|
||||||
|
cb.head = 0
|
||||||
|
cb.size = cap
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate how much space is available from head to end of buffer
|
||||||
|
firstPart := cap - cb.head
|
||||||
|
if firstPart >= len(p) {
|
||||||
|
// All data fits without wrapping
|
||||||
|
copy(cb.data[cb.head:], p)
|
||||||
|
cb.head = (cb.head + len(p)) % cap
|
||||||
|
} else {
|
||||||
|
// Data wraps around
|
||||||
|
copy(cb.data[cb.head:], p[:firstPart])
|
||||||
|
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||||
|
cb.head = len(p) - firstPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update size
|
||||||
|
cb.size += len(p)
|
||||||
|
if cb.size > cap {
|
||||||
|
cb.size = cap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||||
|
// Returns a new slice (copy), not a view into internal buffer.
|
||||||
|
func (cb *circularBuffer) GetHistory() []byte {
|
||||||
|
if cb.size == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]byte, cb.size)
|
||||||
|
cap := len(cb.data)
|
||||||
|
|
||||||
|
// Calculate start position (oldest data)
|
||||||
|
start := (cb.head - cb.size + cap) % cap
|
||||||
|
|
||||||
|
if start+cb.size <= cap {
|
||||||
|
// Data is contiguous, single copy
|
||||||
|
copy(result, cb.data[start:start+cb.size])
|
||||||
|
} else {
|
||||||
|
// Data wraps around, two copies
|
||||||
|
firstPart := cap - start
|
||||||
|
copy(result[:firstPart], cb.data[start:])
|
||||||
|
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -15,12 +97,14 @@ const (
|
|||||||
LevelInfo
|
LevelInfo
|
||||||
LevelWarn
|
LevelWarn
|
||||||
LevelError
|
LevelError
|
||||||
|
|
||||||
|
LogBufferSize = 100 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
eventbus *event.Dispatcher
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
buffer *ring.Ring
|
buffer *circularBuffer
|
||||||
bufferMu sync.RWMutex
|
bufferMu sync.RWMutex
|
||||||
|
|
||||||
// typically this can be os.Stdout
|
// typically this can be os.Stdout
|
||||||
@@ -29,6 +113,9 @@ type LogMonitor struct {
|
|||||||
// logging levels
|
// logging levels
|
||||||
level LogLevel
|
level LogLevel
|
||||||
prefix string
|
prefix string
|
||||||
|
|
||||||
|
// timestamps
|
||||||
|
timeFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogMonitor() *LogMonitor {
|
func NewLogMonitor() *LogMonitor {
|
||||||
@@ -37,11 +124,12 @@ func NewLogMonitor() *LogMonitor {
|
|||||||
|
|
||||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
return &LogMonitor{
|
return &LogMonitor{
|
||||||
clients: make(map[chan []byte]bool),
|
eventbus: event.NewDispatcherConfig(1000),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: nil, // lazy initialized on first Write
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
level: LevelInfo,
|
level: LevelInfo,
|
||||||
prefix: "",
|
prefix: "",
|
||||||
|
timeFormat: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.bufferMu.Lock()
|
w.bufferMu.Lock()
|
||||||
bufferCopy := make([]byte, len(p))
|
if w.buffer == nil {
|
||||||
copy(bufferCopy, p)
|
w.buffer = newCircularBuffer(LogBufferSize)
|
||||||
w.buffer.Value = bufferCopy
|
}
|
||||||
w.buffer = w.buffer.Next()
|
w.buffer.Write(p)
|
||||||
w.bufferMu.Unlock()
|
w.bufferMu.Unlock()
|
||||||
|
|
||||||
|
// Make a copy for broadcast to preserve immutability
|
||||||
|
bufferCopy := make([]byte, len(p))
|
||||||
|
copy(bufferCopy, p)
|
||||||
w.broadcast(bufferCopy)
|
w.broadcast(bufferCopy)
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
@@ -69,46 +160,28 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
|||||||
func (w *LogMonitor) GetHistory() []byte {
|
func (w *LogMonitor) GetHistory() []byte {
|
||||||
w.bufferMu.RLock()
|
w.bufferMu.RLock()
|
||||||
defer w.bufferMu.RUnlock()
|
defer w.bufferMu.RUnlock()
|
||||||
|
if w.buffer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.buffer.GetHistory()
|
||||||
|
}
|
||||||
|
|
||||||
var history []byte
|
// Clear releases the buffer memory, making it eligible for GC.
|
||||||
w.buffer.Do(func(p any) {
|
// The buffer will be lazily re-allocated on the next Write.
|
||||||
if p != nil {
|
func (w *LogMonitor) Clear() {
|
||||||
if content, ok := p.([]byte); ok {
|
w.bufferMu.Lock()
|
||||||
history = append(history, content...)
|
w.buffer = nil
|
||||||
}
|
w.bufferMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||||
|
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||||
|
callback(e.Data)
|
||||||
})
|
})
|
||||||
return history
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Subscribe() chan []byte {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
ch := make(chan []byte, 100)
|
|
||||||
w.clients[ch] = true
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
delete(w.clients, ch)
|
|
||||||
close(ch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) broadcast(msg []byte) {
|
func (w *LogMonitor) broadcast(msg []byte) {
|
||||||
w.mu.RLock()
|
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||||
defer w.mu.RUnlock()
|
|
||||||
|
|
||||||
for client := range w.clients {
|
|
||||||
select {
|
|
||||||
case client <- msg:
|
|
||||||
default:
|
|
||||||
// If client buffer is full, skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||||
@@ -123,12 +196,22 @@ func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
|||||||
w.level = level
|
w.level = level
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.timeFormat = timeFormat
|
||||||
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||||
prefix := ""
|
prefix := ""
|
||||||
if w.prefix != "" {
|
if w.prefix != "" {
|
||||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||||
}
|
}
|
||||||
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
timestamp := ""
|
||||||
|
if w.timeFormat != "" {
|
||||||
|
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||||
|
|||||||
@@ -3,45 +3,38 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLogMonitor(t *testing.T) {
|
func TestLogMonitor(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
// Test subscription
|
// A WaitGroup is used to wait for all the expected writes to complete
|
||||||
client1 := logMonitor.Subscribe()
|
var wg sync.WaitGroup
|
||||||
client2 := logMonitor.Subscribe()
|
|
||||||
|
|
||||||
defer logMonitor.Unsubscribe(client1)
|
|
||||||
defer logMonitor.Unsubscribe(client2)
|
|
||||||
|
|
||||||
client1Messages := make([]byte, 0)
|
client1Messages := make([]byte, 0)
|
||||||
client2Messages := make([]byte, 0)
|
client2Messages := make([]byte, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
wg.Add(1)
|
client1Messages = append(client1Messages, data...)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
go func() {
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
defer wg.Done()
|
client2Messages = append(client2Messages, data...)
|
||||||
for {
|
wg.Done()
|
||||||
select {
|
})()
|
||||||
case data := <-client1:
|
|
||||||
client1Messages = append(client1Messages, data...)
|
wg.Add(6) // 2 x 3 writes
|
||||||
case data := <-client2:
|
|
||||||
client2Messages = append(client2Messages, data...)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logMonitor.Write([]byte("1"))
|
logMonitor.Write([]byte("1"))
|
||||||
logMonitor.Write([]byte("2"))
|
logMonitor.Write([]byte("2"))
|
||||||
logMonitor.Write([]byte("3"))
|
logMonitor.Write([]byte("3"))
|
||||||
|
|
||||||
// Wait for the goroutine to finish
|
// wait for all writes to complete
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Check the buffer
|
// Check the buffer
|
||||||
@@ -93,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
|
|||||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrite_LogTimeFormat(t *testing.T) {
|
||||||
|
// Create a new LogMonitor instance
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Enable timestamps
|
||||||
|
lm.timeFormat = time.RFC3339
|
||||||
|
|
||||||
|
// Write the message to the LogMonitor
|
||||||
|
lm.Info("Hello, World!")
|
||||||
|
|
||||||
|
// Get the history from the LogMonitor
|
||||||
|
history := lm.GetHistory()
|
||||||
|
|
||||||
|
timestamp := ""
|
||||||
|
fields := strings.Fields(string(history))
|
||||||
|
if len(fields) > 0 {
|
||||||
|
timestamp = fields[0]
|
||||||
|
} else {
|
||||||
|
t.Fatalf("Cannot extract string from history")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := time.Parse(time.RFC3339, timestamp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Cannot find timestamp: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||||
|
// Create a small buffer to test wrap-around
|
||||||
|
cb := newCircularBuffer(10)
|
||||||
|
|
||||||
|
// Write "hello" (5 bytes)
|
||||||
|
cb.Write([]byte("hello"))
|
||||||
|
if got := string(cb.GetHistory()); got != "hello" {
|
||||||
|
t.Errorf("Expected 'hello', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write "world" (5 bytes) - buffer now full
|
||||||
|
cb.Write([]byte("world"))
|
||||||
|
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||||
|
t.Errorf("Expected 'helloworld', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||||
|
cb.Write([]byte("12345"))
|
||||||
|
if got := string(cb.GetHistory()); got != "world12345" {
|
||||||
|
t.Errorf("Expected 'world12345', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write data larger than buffer capacity
|
||||||
|
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||||
|
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||||
|
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||||
|
// Test empty buffer
|
||||||
|
cb := newCircularBuffer(10)
|
||||||
|
if got := cb.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test exact capacity
|
||||||
|
cb.Write([]byte("1234567890"))
|
||||||
|
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||||
|
t.Errorf("Expected '1234567890', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test write exactly at capacity boundary
|
||||||
|
cb = newCircularBuffer(10)
|
||||||
|
cb.Write([]byte("12345"))
|
||||||
|
cb.Write([]byte("67890"))
|
||||||
|
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||||
|
t.Errorf("Expected '1234567890', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Buffer should be nil before any writes
|
||||||
|
if lm.buffer != nil {
|
||||||
|
t.Error("Expected buffer to be nil before first write")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistory should return nil when buffer is nil
|
||||||
|
if got := lm.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil history before first write, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write should lazily initialize the buffer
|
||||||
|
lm.Write([]byte("test"))
|
||||||
|
|
||||||
|
if lm.buffer == nil {
|
||||||
|
t.Error("Expected buffer to be initialized after write")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := string(lm.GetHistory()); got != "test" {
|
||||||
|
t.Errorf("Expected 'test', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_Clear(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Write some data
|
||||||
|
lm.Write([]byte("hello"))
|
||||||
|
if got := string(lm.GetHistory()); got != "hello" {
|
||||||
|
t.Errorf("Expected 'hello', got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear should release the buffer
|
||||||
|
lm.Clear()
|
||||||
|
|
||||||
|
if lm.buffer != nil {
|
||||||
|
t.Error("Expected buffer to be nil after Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := lm.GetHistory(); got != nil {
|
||||||
|
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
|
// Write, clear, then write again
|
||||||
|
lm.Write([]byte("first"))
|
||||||
|
lm.Clear()
|
||||||
|
lm.Write([]byte("second"))
|
||||||
|
|
||||||
|
if got := string(lm.GetHistory()); got != "second" {
|
||||||
|
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||||
|
// Test data of varying sizes
|
||||||
|
smallMsg := []byte("small message\n")
|
||||||
|
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||||
|
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||||
|
|
||||||
|
b.Run("SmallWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(smallMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("MediumWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("LargeWrite", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(largeMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WithSubscribers", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
// Add some subscribers
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
lm.OnLogData(func(data []byte) {})
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("GetHistory", func(b *testing.B) {
|
||||||
|
lm := NewLogMonitorWriter(io.Discard)
|
||||||
|
// Pre-populate with data
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
lm.Write(mediumMsg)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lm.GetHistory()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Benchmark Results - MBP M1 Pro
|
||||||
|
|
||||||
|
Before (ring.Ring):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|----------|-----------|
|
||||||
|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||||
|
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||||
|
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||||
|
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||||
|
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||||
|
|
||||||
|
After (circularBuffer 10KB):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|----------|-----------|
|
||||||
|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||||
|
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||||
|
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||||
|
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||||
|
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||||
|
|
||||||
|
After (circularBuffer 100KB):
|
||||||
|
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||||
|
|---------------------------------|------------|-----------|-----------|
|
||||||
|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||||
|
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||||
|
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||||
|
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||||
|
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||||
|
|
||||||
|
Summary:
|
||||||
|
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||||
|
- Allocations: reduced from 2 to 1 across all operations
|
||||||
|
- Small/medium writes: ~1.1-1.6x faster
|
||||||
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,515 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"compress/gzip"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
|
type TokenMetrics struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CachedTokens int `json:"cache_tokens"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
|
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
HasCapture bool `json:"has_capture"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReqRespCapture struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
ReqPath string `json:"req_path"`
|
||||||
|
ReqHeaders map[string]string `json:"req_headers"`
|
||||||
|
ReqBody []byte `json:"req_body"`
|
||||||
|
RespHeaders map[string]string `json:"resp_headers"`
|
||||||
|
RespBody []byte `json:"resp_body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the approximate memory usage of this capture in bytes
|
||||||
|
func (c *ReqRespCapture) Size() int {
|
||||||
|
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
||||||
|
for k, v := range c.ReqHeaders {
|
||||||
|
size += len(k) + len(v)
|
||||||
|
}
|
||||||
|
for k, v := range c.RespHeaders {
|
||||||
|
size += len(k) + len(v)
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenMetricsEvent represents a token metrics event
|
||||||
|
type TokenMetricsEvent struct {
|
||||||
|
Metrics TokenMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e TokenMetricsEvent) Type() uint32 {
|
||||||
|
return TokenMetricsEventID // defined in events.go
|
||||||
|
}
|
||||||
|
|
||||||
|
// metricsMonitor parses llama-server output for token statistics
|
||||||
|
type metricsMonitor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
metrics []TokenMetrics
|
||||||
|
maxMetrics int
|
||||||
|
nextID int
|
||||||
|
logger *LogMonitor
|
||||||
|
|
||||||
|
// capture fields
|
||||||
|
enableCaptures bool
|
||||||
|
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
||||||
|
captureOrder []int // track insertion order for FIFO eviction
|
||||||
|
captureSize int // current total size in bytes
|
||||||
|
maxCaptureSize int // max bytes for captures
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||||
|
// capture buffer size in megabytes; 0 disables captures.
|
||||||
|
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||||
|
return &metricsMonitor{
|
||||||
|
logger: logger,
|
||||||
|
maxMetrics: maxMetrics,
|
||||||
|
enableCaptures: captureBufferMB > 0,
|
||||||
|
captures: make(map[int]ReqRespCapture),
|
||||||
|
captureOrder: make([]int, 0),
|
||||||
|
captureSize: 0,
|
||||||
|
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMetrics adds a new metric to the collection and publishes an event.
|
||||||
|
// Returns the assigned metric ID.
|
||||||
|
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
metric.ID = mp.nextID
|
||||||
|
mp.nextID++
|
||||||
|
mp.metrics = append(mp.metrics, metric)
|
||||||
|
if len(mp.metrics) > mp.maxMetrics {
|
||||||
|
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||||
|
}
|
||||||
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
|
return metric.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||||
|
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
||||||
|
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||||
|
if !mp.enableCaptures {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
captureSize := capture.Size()
|
||||||
|
if captureSize > mp.maxCaptureSize {
|
||||||
|
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict oldest (FIFO) until room available
|
||||||
|
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||||
|
oldestID := mp.captureOrder[0]
|
||||||
|
mp.captureOrder = mp.captureOrder[1:]
|
||||||
|
if evicted, exists := mp.captures[oldestID]; exists {
|
||||||
|
mp.captureSize -= evicted.Size()
|
||||||
|
delete(mp.captures, oldestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.captures[capture.ID] = capture
|
||||||
|
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||||
|
mp.captureSize += captureSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCaptureByID returns a capture by its ID, or nil if not found.
|
||||||
|
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
if capture, exists := mp.captures[id]; exists {
|
||||||
|
return &capture
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetrics returns a copy of the current metrics
|
||||||
|
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]TokenMetrics, len(mp.metrics))
|
||||||
|
copy(result, mp.metrics)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsJSON returns metrics as JSON
|
||||||
|
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
return json.Marshal(mp.metrics)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapHandler wraps the proxy handler to extract token metrics
|
||||||
|
// if wrapHandler returns an error it is safe to assume that no
|
||||||
|
// data was sent to the client
|
||||||
|
func (mp *metricsMonitor) wrapHandler(
|
||||||
|
modelID string,
|
||||||
|
writer gin.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||||
|
) error {
|
||||||
|
// Capture request body and headers if captures enabled
|
||||||
|
var reqBody []byte
|
||||||
|
var reqHeaders map[string]string
|
||||||
|
if mp.enableCaptures {
|
||||||
|
if request.Body != nil {
|
||||||
|
var err error
|
||||||
|
reqBody, err = io.ReadAll(request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||||
|
}
|
||||||
|
request.Body.Close()
|
||||||
|
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||||
|
}
|
||||||
|
reqHeaders = make(map[string]string)
|
||||||
|
for key, values := range request.Header {
|
||||||
|
if len(values) > 0 {
|
||||||
|
reqHeaders[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
redactHeaders(reqHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := newBodyCopier(writer)
|
||||||
|
|
||||||
|
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||||
|
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||||
|
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := next(modelID, recorder, request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// after this point we have to assume that data was sent to the client
|
||||||
|
// and we can only log errors but not send them to clients
|
||||||
|
|
||||||
|
if recorder.Status() != http.StatusOK {
|
||||||
|
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize default metrics - these will always be recorded
|
||||||
|
tm := TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompress if needed
|
||||||
|
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||||
|
var err error
|
||||||
|
body, err = decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
|
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm = parsed
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if gjson.ValidBytes(body) {
|
||||||
|
parsed := gjson.ParseBytes(body)
|
||||||
|
usage := parsed.Get("usage")
|
||||||
|
timings := parsed.Get("timings")
|
||||||
|
|
||||||
|
// extract timings for infill - response is an array, timings are in the last element
|
||||||
|
// see #463
|
||||||
|
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||||
|
if arr := parsed.Array(); len(arr) > 0 {
|
||||||
|
timings = arr[len(arr)-1].Get("timings")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.Exists() || timings.Exists() {
|
||||||
|
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||||
|
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
tm = parsedMetrics
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build capture if enabled and determine if it will be stored
|
||||||
|
var capture *ReqRespCapture
|
||||||
|
if mp.enableCaptures {
|
||||||
|
respHeaders := make(map[string]string)
|
||||||
|
for key, values := range recorder.Header() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
respHeaders[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
redactHeaders(respHeaders)
|
||||||
|
delete(respHeaders, "Content-Encoding")
|
||||||
|
capture = &ReqRespCapture{
|
||||||
|
ReqPath: request.URL.Path,
|
||||||
|
ReqHeaders: reqHeaders,
|
||||||
|
ReqBody: reqBody,
|
||||||
|
RespHeaders: respHeaders,
|
||||||
|
RespBody: body,
|
||||||
|
}
|
||||||
|
// Only set HasCapture if the capture will actually be stored (not too large)
|
||||||
|
if capture.Size() <= mp.maxCaptureSize {
|
||||||
|
tm.HasCapture = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
metricID := mp.addMetrics(tm)
|
||||||
|
|
||||||
|
// Store capture if enabled
|
||||||
|
if capture != nil {
|
||||||
|
capture.ID = metricID
|
||||||
|
mp.addCapture(*capture)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||||
|
// Iterate **backwards** through the body looking for the data payload with
|
||||||
|
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||||
|
|
||||||
|
// Start from the end of the body and scan backwards for newlines
|
||||||
|
pos := len(body)
|
||||||
|
for pos > 0 {
|
||||||
|
// Find the previous newline (or start of body)
|
||||||
|
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||||
|
if lineStart == -1 {
|
||||||
|
lineStart = 0
|
||||||
|
} else {
|
||||||
|
lineStart++ // Move past the newline
|
||||||
|
}
|
||||||
|
|
||||||
|
line := bytes.TrimSpace(body[lineStart:pos])
|
||||||
|
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payload always follows "data:"
|
||||||
|
prefix := []byte("data:")
|
||||||
|
if !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
// [DONE] line itself contains nothing of interest.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.ValidBytes(data) {
|
||||||
|
parsed := gjson.ParseBytes(data)
|
||||||
|
usage := parsed.Get("usage")
|
||||||
|
timings := parsed.Get("timings")
|
||||||
|
|
||||||
|
if usage.Exists() || timings.Exists() {
|
||||||
|
return parseMetrics(modelID, start, usage, timings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||||
|
// default values
|
||||||
|
cachedTokens := -1 // unknown or missing data
|
||||||
|
outputTokens := 0
|
||||||
|
inputTokens := 0
|
||||||
|
|
||||||
|
// timings data
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
durationMs := int(time.Since(start).Milliseconds())
|
||||||
|
|
||||||
|
if usage.Exists() {
|
||||||
|
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||||
|
// v1/chat/completions
|
||||||
|
inputTokens = int(pt.Int())
|
||||||
|
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||||
|
// v1/messages
|
||||||
|
inputTokens = int(it.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||||
|
// v1/chat/completions
|
||||||
|
outputTokens = int(ct.Int())
|
||||||
|
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||||
|
outputTokens = int(ot.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||||
|
cachedTokens = int(ct.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = int(timings.Get("prompt_n").Int())
|
||||||
|
outputTokens = int(timings.Get("predicted_n").Int())
|
||||||
|
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||||
|
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||||
|
durationMs = int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||||
|
|
||||||
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
|
cachedTokens = int(cachedValue.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
CachedTokens: cachedTokens,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
DurationMs: durationMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressBody decompresses the body based on Content-Encoding header
|
||||||
|
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||||
|
case "gzip":
|
||||||
|
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
return io.ReadAll(reader)
|
||||||
|
case "deflate":
|
||||||
|
reader := flate.NewReader(bytes.NewReader(body))
|
||||||
|
defer reader.Close()
|
||||||
|
return io.ReadAll(reader)
|
||||||
|
default:
|
||||||
|
return body, nil // Return as-is for unknown/no encoding
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseBodyCopier records the response body and writes to the original response writer
|
||||||
|
// while also capturing it in a buffer for later processing
|
||||||
|
type responseBodyCopier struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
tee io.Writer
|
||||||
|
start time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||||
|
bodyBuffer := &bytes.Buffer{}
|
||||||
|
return &responseBodyCopier{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: bodyBuffer,
|
||||||
|
tee: io.MultiWriter(w, bodyBuffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||||
|
if w.start.IsZero() {
|
||||||
|
w.start = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single write operation that writes to both the response and buffer
|
||||||
|
return w.tee.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Header() http.Header {
|
||||||
|
return w.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) StartTime() time.Time {
|
||||||
|
return w.start
|
||||||
|
}
|
||||||
|
|
||||||
|
// sensitiveHeaders lists headers that should be redacted in captures
|
||||||
|
var sensitiveHeaders = map[string]bool{
|
||||||
|
"authorization": true,
|
||||||
|
"proxy-authorization": true,
|
||||||
|
"cookie": true,
|
||||||
|
"set-cookie": true,
|
||||||
|
"x-api-key": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||||
|
func redactHeaders(headers map[string]string) {
|
||||||
|
for key := range headers {
|
||||||
|
if sensitiveHeaders[strings.ToLower(key)] {
|
||||||
|
headers[key] = "[REDACTED]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||||
|
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||||
|
// preferences while ensuring we can parse response bodies for metrics.
|
||||||
|
func filterAcceptEncoding(acceptEncoding string) string {
|
||||||
|
if acceptEncoding == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||||
|
var filtered []string
|
||||||
|
|
||||||
|
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||||
|
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||||
|
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
|
||||||
|
if supported[strings.ToLower(encoding)] {
|
||||||
|
filtered = append(filtered, strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(filtered, ", ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type peerProxyMember struct {
|
||||||
|
peerID string
|
||||||
|
reverseProxy *httputil.ReverseProxy
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerProxy struct {
|
||||||
|
peers config.PeerDictionaryConfig
|
||||||
|
proxyMap map[string]*peerProxyMember
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||||
|
proxyMap := make(map[string]*peerProxyMember)
|
||||||
|
|
||||||
|
// Sort peer IDs for consistent iteration order
|
||||||
|
peerIDs := make([]string, 0, len(peers))
|
||||||
|
for peerID := range peers {
|
||||||
|
peerIDs = append(peerIDs, peerID)
|
||||||
|
}
|
||||||
|
sort.Strings(peerIDs)
|
||||||
|
|
||||||
|
// Create a shared transport with reasonable timeouts for peer connections
|
||||||
|
// these can be tuned with feedback later
|
||||||
|
peerTransport := &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second, // Connection timeout
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
peer := peers[peerID]
|
||||||
|
// Create reverse proxy for this peer
|
||||||
|
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||||
|
reverseProxy.Transport = peerTransport
|
||||||
|
|
||||||
|
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||||
|
originalDirector := reverseProxy.Director
|
||||||
|
reverseProxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
// Ensure Host header matches target URL for remote proxying
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||||
|
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||||
|
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||||
|
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||||
|
}
|
||||||
|
http.Error(w, errMsg, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
pp := &peerProxyMember{
|
||||||
|
peerID: peerID,
|
||||||
|
reverseProxy: reverseProxy,
|
||||||
|
apiKey: peer.ApiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map each model to this peer's proxy
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
if _, found := proxyMap[modelID]; found {
|
||||||
|
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
proxyMap[modelID] = pp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PeerProxy{
|
||||||
|
peers: peers,
|
||||||
|
proxyMap: proxyMap,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||||
|
_, found := p.proxyMap[modelID]
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||||
|
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||||
|
pp, found := p.proxyMap[modelID]
|
||||||
|
if !found {
|
||||||
|
return config.Filters{}
|
||||||
|
}
|
||||||
|
// Get the peer config using the peerID
|
||||||
|
peer, found := p.peers[pp.peerID]
|
||||||
|
if !found {
|
||||||
|
return config.Filters{}
|
||||||
|
}
|
||||||
|
return peer.Filters
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||||
|
return p.peers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||||
|
pp, found := p.proxyMap[model_id]
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject API key if configured for this peer
|
||||||
|
if pp.apiKey != "" {
|
||||||
|
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||||
|
request.Header.Set("x-api-key", pp.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
pp.reverseProxy.ServeHTTP(writer, request)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,268 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||||
|
peers := config.PeerDictionaryConfig{}
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, pm)
|
||||||
|
assert.Empty(t, pm.proxyMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "test-key",
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pm.proxyMap, 2)
|
||||||
|
assert.True(t, pm.HasPeerModel("model-a"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-b"))
|
||||||
|
assert.False(t, pm.HasPeerModel("model-c"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"model-a", "model-b"},
|
||||||
|
},
|
||||||
|
"peer2": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"model-c", "model-d"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, pm.proxyMap, 4)
|
||||||
|
assert.True(t, pm.HasPeerModel("model-a"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-b"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-c"))
|
||||||
|
assert.True(t, pm.HasPeerModel("model-d"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||||
|
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||||
|
// should be mapped, and a warning should be logged
|
||||||
|
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"alpha-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL1,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
"beta-peer": config.PeerConfig{
|
||||||
|
Proxy: "http://peer2.example.com:8080",
|
||||||
|
ProxyURL: proxyURL2,
|
||||||
|
Models: []string{"duplicate-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Should only have one entry for the duplicate model
|
||||||
|
assert.Len(t, pm.proxyMap, 1)
|
||||||
|
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasPeerModel(t *testing.T) {
|
||||||
|
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: "http://peer1.example.com:8080",
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"existing-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||||
|
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||||
|
peers := config.PeerDictionaryConfig{}
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_Success(t *testing.T) {
|
||||||
|
// Create a test server to act as the peer
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("response from peer"))
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "response from peer", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||||
|
// Create a test server that checks for the Authorization header
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "secret-api-key",
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||||
|
// Create a test server that checks for the Authorization header
|
||||||
|
var receivedAuthHeader string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedAuthHeader = r.Header.Get("Authorization")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
ApiKey: "", // No API key
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, receivedAuthHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||||
|
// Create a test server that checks the Host header
|
||||||
|
var receivedHost string
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedHost = r.Host
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// The Host header should be set to the target URL's host
|
||||||
|
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||||
|
// Create a test server that returns SSE content type
|
||||||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
proxyURL, _ := url.Parse(testServer.URL)
|
||||||
|
peers := config.PeerDictionaryConfig{
|
||||||
|
"peer1": config.PeerConfig{
|
||||||
|
Proxy: testServer.URL,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Models: []string{"test-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pm, err := NewPeerProxy(peers, testLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err = pm.ProxyRequest("test-model", w, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||||
|
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||||
|
}
|
||||||
@@ -2,17 +2,23 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessState string
|
type ProcessState string
|
||||||
@@ -23,9 +29,6 @@ const (
|
|||||||
StateReady ProcessState = ProcessState("ready")
|
StateReady ProcessState = ProcessState("ready")
|
||||||
StateStopping ProcessState = ProcessState("stopping")
|
StateStopping ProcessState = ProcessState("stopping")
|
||||||
|
|
||||||
// failed a health check on start and will not be recovered
|
|
||||||
StateFailed ProcessState = ProcessState("failed")
|
|
||||||
|
|
||||||
// process is shutdown and will not be restarted
|
// process is shutdown and will not be restarted
|
||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
@@ -38,12 +41,17 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config config.ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
reverseProxy *httputil.ReverseProxy
|
||||||
|
|
||||||
// for p.cmd.Wait() select { ... }
|
// PR #155 called to cancel the upstream process
|
||||||
cmdWaitChan chan error
|
cmdMutex sync.RWMutex
|
||||||
|
cancelUpstream context.CancelFunc
|
||||||
|
|
||||||
|
// closed when command exits
|
||||||
|
cmdWaitChan chan struct{}
|
||||||
|
|
||||||
processLogger *LogMonitor
|
processLogger *LogMonitor
|
||||||
proxyLogger *LogMonitor
|
proxyLogger *LogMonitor
|
||||||
@@ -51,53 +59,71 @@ type Process struct {
|
|||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
healthCheckLoopInterval time.Duration
|
healthCheckLoopInterval time.Duration
|
||||||
|
|
||||||
lastRequestHandled time.Time
|
lastRequestHandledMutex sync.RWMutex
|
||||||
|
lastRequestHandled time.Time
|
||||||
|
|
||||||
stateMutex sync.RWMutex
|
stateMutex sync.RWMutex
|
||||||
state ProcessState
|
state ProcessState
|
||||||
|
|
||||||
inFlightRequests sync.WaitGroup
|
inFlightRequests sync.WaitGroup
|
||||||
|
inFlightRequestsCount atomic.Int32
|
||||||
|
|
||||||
// used to block on multiple start() calls
|
// used to block on multiple start() calls
|
||||||
waitStarting sync.WaitGroup
|
waitStarting sync.WaitGroup
|
||||||
|
|
||||||
// for managing shutdown state
|
|
||||||
shutdownCtx context.Context
|
|
||||||
shutdownCancel context.CancelFunc
|
|
||||||
|
|
||||||
// for managing concurrency limits
|
// for managing concurrency limits
|
||||||
concurrencyLimitSemaphore chan struct{}
|
concurrencyLimitSemaphore chan struct{}
|
||||||
|
|
||||||
// stop timeout waiting for graceful shutdown
|
// used for testing to override the default value
|
||||||
gracefulStopTimeout time.Duration
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// track the number of failed starts
|
||||||
|
failedStartCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
concurrentLimit := 10
|
concurrentLimit := 10
|
||||||
if config.ConcurrencyLimit > 0 {
|
if config.ConcurrencyLimit > 0 {
|
||||||
concurrentLimit = config.ConcurrencyLimit
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
} else {
|
|
||||||
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup the reverse proxy.
|
||||||
|
proxyURL, err := url.Parse(config.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reverseProxy *httputil.ReverseProxy
|
||||||
|
if proxyURL != nil {
|
||||||
|
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
cmdWaitChan: make(chan error, 1),
|
reverseProxy: reverseProxy,
|
||||||
|
cancelUpstream: nil,
|
||||||
processLogger: processLogger,
|
processLogger: processLogger,
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
|
||||||
shutdownCancel: cancel,
|
|
||||||
|
|
||||||
// concurrency limit
|
// concurrency limit
|
||||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
|
|
||||||
|
// To be removed when migration over exec.CommandContext is complete
|
||||||
// stop timeout
|
// stop timeout
|
||||||
gracefulStopTimeout: 5 * time.Second,
|
gracefulStopTimeout: 10 * time.Second,
|
||||||
|
cmdWaitChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,6 +132,20 @@ func (p *Process) LogMonitor() *LogMonitor {
|
|||||||
return p.processLogger
|
return p.processLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setLastRequestHandled sets the last request handled time in a thread-safe manner.
|
||||||
|
func (p *Process) setLastRequestHandled(t time.Time) {
|
||||||
|
p.lastRequestHandledMutex.Lock()
|
||||||
|
defer p.lastRequestHandledMutex.Unlock()
|
||||||
|
p.lastRequestHandled = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLastRequestHandled gets the last request handled time in a thread-safe manner.
|
||||||
|
func (p *Process) getLastRequestHandled() time.Time {
|
||||||
|
p.lastRequestHandledMutex.RLock()
|
||||||
|
defer p.lastRequestHandledMutex.RUnlock()
|
||||||
|
return p.lastRequestHandled
|
||||||
|
}
|
||||||
|
|
||||||
// custom error types for swapping state
|
// custom error types for swapping state
|
||||||
var (
|
var (
|
||||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
@@ -129,7 +169,15 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
|
|
||||||
|
// Atomically increment waitStarting when entering StateStarting
|
||||||
|
// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
|
||||||
|
if newState == StateStarting {
|
||||||
|
p.waitStarting.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
|
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,12 +187,12 @@ func isValidTransition(from, to ProcessState) bool {
|
|||||||
case StateStopped:
|
case StateStopped:
|
||||||
return to == StateStarting
|
return to == StateStarting
|
||||||
case StateStarting:
|
case StateStarting:
|
||||||
return to == StateReady || to == StateFailed || to == StateStopping
|
return to == StateReady || to == StateStopping || to == StateStopped
|
||||||
case StateReady:
|
case StateReady:
|
||||||
return to == StateStopping
|
return to == StateStopping
|
||||||
case StateStopping:
|
case StateStopping:
|
||||||
return to == StateStopped || to == StateShutdown
|
return to == StateStopped || to == StateShutdown
|
||||||
case StateFailed, StateShutdown:
|
case StateShutdown:
|
||||||
return false // No transitions allowed from these states
|
return false // No transitions allowed from these states
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -156,6 +204,15 @@ func (p *Process) CurrentState() ProcessState {
|
|||||||
return p.state
|
return p.state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forceState forces the process state to the new state with mutex protection.
|
||||||
|
// This should only be used in exceptional cases where the normal state transition
|
||||||
|
// validation via swapState() cannot be used.
|
||||||
|
func (p *Process) forceState(newState ProcessState) {
|
||||||
|
p.stateMutex.Lock()
|
||||||
|
defer p.stateMutex.Unlock()
|
||||||
|
p.state = newState
|
||||||
|
}
|
||||||
|
|
||||||
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
|
||||||
// it is a private method because starting is automatic but stopping can be called
|
// it is a private method because starting is automatic but stopping can be called
|
||||||
// at any time.
|
// at any time.
|
||||||
@@ -189,39 +246,42 @@ func (p *Process) start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.waitStarting.Add(1)
|
// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
|
||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
|
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.processLogger
|
p.cmd.Stdout = p.processLogger
|
||||||
p.cmd.Stderr = p.processLogger
|
p.cmd.Stderr = p.processLogger
|
||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||||
|
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||||
|
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||||
|
setProcAttributes(p.cmd)
|
||||||
|
|
||||||
|
p.cmdMutex.Lock()
|
||||||
|
p.cancelUpstream = ctxCancelUpstream
|
||||||
|
p.cmdWaitChan = make(chan struct{})
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
|
|
||||||
|
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
// Set process state to failed
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||||
|
p.forceState(StateStopped) // force it into a stopped state
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
err, curState, swapErr,
|
strings.Join(args, " "), err, curState, swapErr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture the exit error for later signalling
|
// Capture the exit error for later signalling
|
||||||
go func() {
|
go p.waitForCmd()
|
||||||
exitErr := p.cmd.Wait()
|
|
||||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
|
||||||
|
|
||||||
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
|
|
||||||
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
|
|
||||||
if exitErr != nil && exitErr.Error() == "signal: killed" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.cmdWaitChan <- exitErr
|
|
||||||
}()
|
|
||||||
|
|
||||||
// One of three things can happen at this stage:
|
// One of three things can happen at this stage:
|
||||||
// 1. The command exits unexpectedly
|
// 1. The command exits unexpectedly
|
||||||
@@ -237,67 +297,38 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||||
if checkEndpoint != "none" {
|
if checkEndpoint != "none" {
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
|
||||||
context.Background(),
|
|
||||||
checkStartTime.Add(maxDuration),
|
|
||||||
)
|
|
||||||
defer cancelHealthCheck()
|
|
||||||
|
|
||||||
loop:
|
|
||||||
// Ready Check loop
|
// Ready Check loop
|
||||||
for {
|
for {
|
||||||
select {
|
currentState := p.CurrentState()
|
||||||
case <-checkDeadline.Done():
|
if currentState != StateStarting {
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
if currentState == StateStopped {
|
||||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||||
} else {
|
|
||||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
|
||||||
}
|
}
|
||||||
case <-p.shutdownCtx.Done():
|
|
||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
case exitErr := <-p.cmdWaitChan:
|
|
||||||
if exitErr != nil {
|
|
||||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
|
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
|
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
|
||||||
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
|
||||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
|
||||||
cancelHealthCheck()
|
|
||||||
break loop
|
|
||||||
} else {
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
endTime, _ := checkDeadline.Deadline()
|
|
||||||
ttl := time.Until(endTime)
|
|
||||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if time.Since(checkStartTime) > maxDuration {
|
||||||
|
p.stopCommand()
|
||||||
|
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
|
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||||
|
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
<-time.After(p.healthCheckLoopInterval)
|
<-time.After(p.healthCheckLoopInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -313,10 +344,12 @@ func (p *Process) start() error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for all inflight requests to complete and ticker
|
// skip the TTL check if there are inflight requests
|
||||||
p.inFlightRequests.Wait()
|
if p.inFlightRequestsCount.Load() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.getLastRequestHandled()) > maxDuration {
|
||||||
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
@@ -328,6 +361,7 @@ func (p *Process) start() error {
|
|||||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
} else {
|
} else {
|
||||||
|
p.failedStartCount = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -351,20 +385,13 @@ func (p *Process) StopImmediately() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
p.stopCommand()
|
||||||
p.stopCommand(p.gracefulStopTimeout)
|
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
|
||||||
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
@@ -372,63 +399,53 @@ func (p *Process) StopImmediately() {
|
|||||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
// the StateShutdown state, it can not be started again.
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
p.stopCommand(p.gracefulStopTimeout)
|
return
|
||||||
p.state = StateShutdown
|
}
|
||||||
|
|
||||||
|
p.stopCommand()
|
||||||
|
// just force it to this state since there is no recovery from shutdown
|
||||||
|
p.forceState(StateShutdown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
func (p *Process) stopCommand() {
|
||||||
stopStartTime := time.Now()
|
stopStartTime := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
|
|
||||||
|
// free the buffer in processLogger so the memory can be recovered
|
||||||
|
p.processLogger.Clear()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
p.cmdMutex.RLock()
|
||||||
defer cancelTimeout()
|
cancelUpstream := p.cancelUpstream
|
||||||
|
cmdWaitChan := p.cmdWaitChan
|
||||||
|
p.cmdMutex.RUnlock()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if cancelUpstream == nil {
|
||||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
cancelUpstream()
|
||||||
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
<-cmdWaitChan
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-sigtermTimeout.Done():
|
|
||||||
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
|
|
||||||
if err := p.cmd.Process.Kill(); err != nil {
|
|
||||||
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
|
|
||||||
}
|
|
||||||
case err := <-p.cmdWaitChan:
|
|
||||||
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
|
||||||
// because if we make it here then the cmd has been successfully running and made it
|
|
||||||
// through the health check. There is a possibility that the cmd crashed after the health check
|
|
||||||
// succeeded but that's not a case llama-swap is handling for now.
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
|
||||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
|
||||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
|
||||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
// wait a short time for a tcp connection to be established
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
|
||||||
|
// give a long time to respond to the health check endpoint
|
||||||
|
// after the connection is established. See issue: 276
|
||||||
|
Timeout: 5000 * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
@@ -451,12 +468,18 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
if p.reverseProxy == nil {
|
||||||
|
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
requestBeginTime := time.Now()
|
requestBeginTime := time.Now()
|
||||||
var startDuration time.Duration
|
var startDuration time.Duration
|
||||||
|
|
||||||
// prevent new requests from being made while stopping or irrecoverable
|
// prevent new requests from being made while stopping or irrecoverable
|
||||||
currentState := p.CurrentState()
|
currentState := p.CurrentState()
|
||||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
if currentState == StateShutdown || currentState == StateStopping {
|
||||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -470,71 +493,387 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
|
p.inFlightRequestsCount.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.setLastRequestHandled(time.Now())
|
||||||
|
p.inFlightRequestsCount.Add(-1)
|
||||||
p.inFlightRequests.Done()
|
p.inFlightRequests.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// for #366
|
||||||
|
// - extract streaming param from request context, should have been set by proxymanager
|
||||||
|
var srw *statusResponseWriter
|
||||||
|
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||||
// start the process on demand
|
// start the process on demand
|
||||||
if p.CurrentState() != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
|
// start a goroutine to stream loading status messages into the response writer
|
||||||
|
// add a sync so the streaming client only runs when the goroutine has exited
|
||||||
|
|
||||||
|
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||||
|
|
||||||
|
// PR #417 (no support for anthropic v1/messages yet)
|
||||||
|
isChatCompletions := strings.HasPrefix(r.URL.Path, "/v1/chat/completions")
|
||||||
|
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming && isChatCompletions {
|
||||||
|
srw = newStatusResponseWriter(p, w)
|
||||||
|
go srw.statusUpdates(swapCtx)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||||
|
}
|
||||||
|
|
||||||
beginStartTime := time.Now()
|
beginStartTime := time.Now()
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
http.Error(w, errstr, http.StatusBadGateway)
|
cancelLoadCtx()
|
||||||
|
if srw != nil {
|
||||||
|
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||||
|
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||||
|
// before closing the connection. Without this, the connection would close before
|
||||||
|
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||||
|
srw.waitForCompletion(100 * time.Millisecond)
|
||||||
|
} else {
|
||||||
|
http.Error(w, errstr, http.StatusBadGateway)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
startDuration = time.Since(beginStartTime)
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
// should trigger srw to stop sending loading events ...
|
||||||
client := &http.Client{}
|
cancelLoadCtx()
|
||||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header = r.Header.Clone()
|
|
||||||
|
|
||||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||||
if err == nil {
|
// disconnects before the response is sent
|
||||||
req.ContentLength = contentLength
|
defer func() {
|
||||||
}
|
if r := recover(); r != nil {
|
||||||
|
if r == http.ErrAbortHandler {
|
||||||
resp, err := client.Do(req)
|
p.proxyLogger.Infof("<%s> recovered from client disconnection during streaming", p.ID)
|
||||||
if err != nil {
|
} else {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
p.proxyLogger.Infof("<%s> recovered from panic: %v", p.ID, r)
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
for k, vv := range resp.Header {
|
|
||||||
for _, v := range vv {
|
|
||||||
w.Header().Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// faster than io.Copy when streaming
|
|
||||||
buf := make([]byte, 32*1024)
|
|
||||||
for {
|
|
||||||
n, err := resp.Body.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flusher, ok := w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
}()
|
||||||
break
|
|
||||||
}
|
if srw != nil {
|
||||||
if err != nil {
|
// Wait for the goroutine to finish writing its final messages
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
const completionTimeout = 1 * time.Second
|
||||||
return
|
if !srw.waitForCompletion(completionTimeout) {
|
||||||
|
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||||
}
|
}
|
||||||
|
p.reverseProxy.ServeHTTP(srw, r)
|
||||||
|
} else {
|
||||||
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
totalTime := time.Since(requestBeginTime)
|
totalTime := time.Since(requestBeginTime)
|
||||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||||
p.ID, r.RequestURI, startDuration, totalTime)
|
p.ID, r.RequestURI, startDuration, totalTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||||
|
func (p *Process) waitForCmd() {
|
||||||
|
exitErr := p.cmd.Wait()
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
|
|
||||||
|
if exitErr != nil {
|
||||||
|
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||||
|
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||||
|
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||||
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||||
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.CurrentState()
|
||||||
|
switch currentState {
|
||||||
|
case StateStopping:
|
||||||
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||||
|
p.forceState(StateStopped)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||||
|
p.forceState(StateStopped) // force it to be in this state
|
||||||
|
}
|
||||||
|
|
||||||
|
p.cmdMutex.Lock()
|
||||||
|
close(p.cmdWaitChan)
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||||
|
func (p *Process) cmdStopUpstreamProcess() error {
|
||||||
|
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||||
|
|
||||||
|
// this should never happen ...
|
||||||
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||||
|
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.CmdStop != "" {
|
||||||
|
// replace ${PID} with the pid of the process
|
||||||
|
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||||
|
if err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||||
|
|
||||||
|
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||||
|
stopCmd.Stdout = p.processLogger
|
||||||
|
stopCmd.Stderr = p.processLogger
|
||||||
|
setProcAttributes(stopCmd)
|
||||||
|
stopCmd.Env = p.cmd.Env
|
||||||
|
|
||||||
|
if err := stopCmd.Run(); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger returns the logger for this process.
|
||||||
|
func (p *Process) Logger() *LogMonitor {
|
||||||
|
return p.processLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
var loadingRemarks = []string{
|
||||||
|
"Still faster than your last standup meeting...",
|
||||||
|
"Reticulating splines...",
|
||||||
|
"Waking up the hamsters...",
|
||||||
|
"Teaching the model manners...",
|
||||||
|
"Convincing the GPU to participate...",
|
||||||
|
"Loading weights (they're heavy)...",
|
||||||
|
"Herding electrons...",
|
||||||
|
"Compiling excuses for the delay...",
|
||||||
|
"Downloading more RAM...",
|
||||||
|
"Asking the model nicely to boot up...",
|
||||||
|
"Bribing CUDA with cookies...",
|
||||||
|
"Still loading (blame VRAM)...",
|
||||||
|
"The model is fashionably late...",
|
||||||
|
"Warming up those tensors...",
|
||||||
|
"Making the neural net do push-ups...",
|
||||||
|
"Your patience is appreciated (really)...",
|
||||||
|
"Almost there (probably)...",
|
||||||
|
"Loading like it's 1999...",
|
||||||
|
"The model forgot where it put its keys...",
|
||||||
|
"Quantum tunneling through layers...",
|
||||||
|
"Negotiating with the PCIe bus...",
|
||||||
|
"Defrosting frozen parameters...",
|
||||||
|
"Teaching attention heads to focus...",
|
||||||
|
"Running the matrix (slowly)...",
|
||||||
|
"Untangling transformer blocks...",
|
||||||
|
"Calibrating the flux capacitor...",
|
||||||
|
"Spinning up the probability wheels...",
|
||||||
|
"Waiting for the GPU to wake from its nap...",
|
||||||
|
"Converting caffeine to compute...",
|
||||||
|
"Allocating virtual patience...",
|
||||||
|
"Performing arcane CUDA rituals...",
|
||||||
|
"The model is stuck in traffic...",
|
||||||
|
"Inflating embeddings...",
|
||||||
|
"Summoning computational demons...",
|
||||||
|
"Pleading with the OOM killer...",
|
||||||
|
"Calculating the meaning of life (still at 42)...",
|
||||||
|
"Training the training wheels...",
|
||||||
|
"Optimizing the optimizer...",
|
||||||
|
"Bootstrapping the bootstrapper...",
|
||||||
|
"Loading loading screen...",
|
||||||
|
"Processing processing logs...",
|
||||||
|
"Buffering buffer overflow jokes...",
|
||||||
|
"The model hit snooze...",
|
||||||
|
"Debugging the debugger...",
|
||||||
|
"Compiling the compiler...",
|
||||||
|
"Parsing the parser (meta)...",
|
||||||
|
"Tokenizing tokens...",
|
||||||
|
"Encoding the encoder...",
|
||||||
|
"Hashing hash browns...",
|
||||||
|
"Forking spoons (not forks)...",
|
||||||
|
"The model is contemplating existence...",
|
||||||
|
"Transcending dimensional barriers...",
|
||||||
|
"Invoking elder tensor gods...",
|
||||||
|
"Unfurling probability clouds...",
|
||||||
|
"Synchronizing parallel universes...",
|
||||||
|
"The GPU is having second thoughts...",
|
||||||
|
"Recalibrating reality matrices...",
|
||||||
|
"Time is an illusion, loading doubly so...",
|
||||||
|
"Convincing bits to flip themselves...",
|
||||||
|
"The model is reading its own documentation...",
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusResponseWriter struct {
|
||||||
|
hasWritten bool
|
||||||
|
writer http.ResponseWriter
|
||||||
|
process *Process
|
||||||
|
wg sync.WaitGroup // Track goroutine completion
|
||||||
|
start time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||||
|
s := &statusResponseWriter{
|
||||||
|
writer: w,
|
||||||
|
process: p,
|
||||||
|
start: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||||
|
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||||
|
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||||
|
s.WriteHeader(http.StatusOK) // send status code 200
|
||||||
|
s.sendLine("━━━━━")
|
||||||
|
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusUpdates sends status updates to the client while the model is loading
|
||||||
|
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||||
|
s.wg.Add(1)
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
// Recover from panics caused by client disconnection
|
||||||
|
// Note: recover() only works within the same goroutine, so we need it here
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
s.process.proxyLogger.Debugf("<%s> statusUpdates recovered from panic (likely client disconnect): %v", s.process.ID, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
duration := time.Since(s.start)
|
||||||
|
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||||
|
s.sendLine("━━━━━")
|
||||||
|
s.sendLine(" ")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create a shuffled copy of loadingRemarks
|
||||||
|
remarks := make([]string, len(loadingRemarks))
|
||||||
|
copy(remarks, loadingRemarks)
|
||||||
|
rand.Shuffle(len(remarks), func(i, j int) {
|
||||||
|
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||||
|
})
|
||||||
|
ri := 0
|
||||||
|
|
||||||
|
// Pick a random duration to send a remark
|
||||||
|
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||||
|
lastRemarkTime := time.Now()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if s.process.CurrentState() == StateReady {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's time for a snarky remark
|
||||||
|
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||||
|
remark := remarks[ri%len(remarks)]
|
||||||
|
ri++
|
||||||
|
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||||
|
lastRemarkTime = time.Now()
|
||||||
|
// Pick a new random duration for the next remark
|
||||||
|
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||||
|
} else {
|
||||||
|
s.sendData(".")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||||
|
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return true
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) sendLine(line string) {
|
||||||
|
s.sendData(line + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) sendData(data string) {
|
||||||
|
// Create the proper SSE JSON structure
|
||||||
|
type Delta struct {
|
||||||
|
ReasoningContent string `json:"reasoning_content"`
|
||||||
|
}
|
||||||
|
type Choice struct {
|
||||||
|
Delta Delta `json:"delta"`
|
||||||
|
}
|
||||||
|
type SSEMessage struct {
|
||||||
|
Choices []Choice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := SSEMessage{
|
||||||
|
Choices: []Choice{
|
||||||
|
{
|
||||||
|
Delta: Delta{
|
||||||
|
ReasoningContent: data,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write SSE formatted data, panic if not able to write
|
||||||
|
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||||
|
}
|
||||||
|
s.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) Header() http.Header {
|
||||||
|
return s.writer.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||||
|
return s.writer.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if s.hasWritten {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.hasWritten = true
|
||||||
|
s.writer.WriteHeader(statusCode)
|
||||||
|
s.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusResponseWriter) Flush() {
|
||||||
|
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
|||||||
// test that the automatic start returns the expected error type
|
// test that the automatic start returns the expected error type
|
||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "nonexistent-command",
|
Cmd: "nonexistent-command",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -105,8 +107,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
|
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
@@ -247,18 +249,14 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
|
||||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||||
|
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
|
||||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
|
||||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||||
@@ -328,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
|
|
||||||
// should run and exit but interrupt the long checkHealthTimeout
|
// should run and exit but interrupt the long checkHealthTimeout
|
||||||
checkHealthTimeout := 5
|
checkHealthTimeout := 5
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "sleep 1",
|
Cmd: "sleep 1",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -338,7 +336,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
process.healthCheckLoopInterval = time.Second // make it faster
|
process.healthCheckLoopInterval = time.Second // make it faster
|
||||||
err := process.start()
|
err := process.start()
|
||||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||||
@@ -397,12 +395,19 @@ func TestProcess_StopImmediately(t *testing.T) {
|
|||||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
// the upstream command
|
// the upstream command
|
||||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("skipping SIGTERM test on Windows ")
|
||||||
|
}
|
||||||
|
|
||||||
expectedMessage := "test_sigkill"
|
expectedMessage := "test_sigkill"
|
||||||
binaryPath := getSimpleResponderPath()
|
binaryPath := getSimpleResponderPath()
|
||||||
port := getTestPort()
|
port := getTestPort()
|
||||||
|
|
||||||
config := ModelConfig{
|
conf := config.ModelConfig{
|
||||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
// to force the process to exit
|
// to force the process to exit
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
@@ -410,7 +415,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// reduce to make testing go faster
|
// reduce to make testing go faster
|
||||||
@@ -432,7 +437,14 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
|
|
||||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||||
// then the unexpected EOF is sent after the kill
|
// then the unexpected EOF is sent after the kill
|
||||||
assert.Equal(t, "1unexpected EOF\n", w.Body.String())
|
if runtime.GOOS == "windows" {
|
||||||
|
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||||
|
} else {
|
||||||
|
// Upstream may be killed mid-response.
|
||||||
|
// Assert an incomplete or partial response.
|
||||||
|
assert.NotEqual(t, "12345", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
close(waitChan)
|
close(waitChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -443,3 +455,117 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
// the request should have been interrupted by SIGKILL
|
// the request should have been interrupted by SIGKILL
|
||||||
<-waitChan
|
<-waitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopCmd(t *testing.T) {
|
||||||
|
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
} else {
|
||||||
|
conf.CmdStop = "kill -TERM ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||||
|
expectedMessage := "test_env_not_emptied"
|
||||||
|
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// ensure that the the default config does not blank out the inherited environment
|
||||||
|
configWEnv := conf
|
||||||
|
|
||||||
|
// ensure the additiona variables are appended to the process' environment
|
||||||
|
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||||
|
|
||||||
|
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||||
|
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||||
|
|
||||||
|
process1.start()
|
||||||
|
defer process1.Stop()
|
||||||
|
process2.start()
|
||||||
|
defer process2.Stop()
|
||||||
|
|
||||||
|
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||||
|
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||||
|
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||||
|
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||||
|
// handled appropriately.
|
||||||
|
//
|
||||||
|
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||||
|
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||||
|
// response is sent from some reason.
|
||||||
|
//
|
||||||
|
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||||
|
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||||
|
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||||
|
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||||
|
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
expectedMessage := "panic_test"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// Start the process
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
|
||||||
|
// Create a custom ResponseWriter that simulates a client disconnect
|
||||||
|
// by panicking when Write is called after headers are sent
|
||||||
|
panicWriter := &panicOnWriteResponseWriter{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
shouldPanic: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a request that will trigger the panic
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||||
|
|
||||||
|
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||||
|
// ProxyRequest should catch and handle this panic gracefully.
|
||||||
|
process.ProxyRequest(panicWriter, req)
|
||||||
|
|
||||||
|
// If we get here, the panic was properly recovered in ProxyRequest
|
||||||
|
// The process should still be in a ready state
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
|
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||||
|
// to simulate a client disconnect after headers are sent
|
||||||
|
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||||
|
type panicOnWriteResponseWriter struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
shouldPanic bool
|
||||||
|
headerWritten bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.headerWritten = true
|
||||||
|
w.ResponseRecorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if w.shouldPanic && w.headerWritten {
|
||||||
|
// Simulate the panic that httputil.ReverseProxy throws
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
return w.ResponseRecorder.Write(b)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttributes sets platform-specific process attributes
|
||||||
|
func setProcAttributes(cmd *exec.Cmd) {
|
||||||
|
// No-op on Unix systems
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttributes sets platform-specific process attributes
|
||||||
|
func setProcAttributes(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
HideWindow: true,
|
||||||
|
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,12 +5,14 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessGroup struct {
|
type ProcessGroup struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
id string
|
id string
|
||||||
swap bool
|
swap bool
|
||||||
exclusive bool
|
exclusive bool
|
||||||
@@ -24,7 +26,7 @@ type ProcessGroup struct {
|
|||||||
lastUsedProcess string
|
lastUsedProcess string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
groupConfig, ok := config.Groups[id]
|
groupConfig, ok := config.Groups[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Unable to find configuration for group id: " + id)
|
panic("Unable to find configuration for group id: " + id)
|
||||||
@@ -44,7 +46,8 @@ func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstream
|
|||||||
// Create a Process for each member in the group
|
// Create a Process for each member in the group
|
||||||
for _, modelID := range groupConfig.Members {
|
for _, modelID := range groupConfig.Members {
|
||||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||||
|
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||||
pg.processes[modelID] = process
|
pg.processes[modelID] = process
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,10 +63,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
if pg.swap {
|
if pg.swap {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
if pg.lastUsedProcess != modelID {
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// is there something already running?
|
||||||
if pg.lastUsedProcess != "" {
|
if pg.lastUsedProcess != "" {
|
||||||
pg.processes[pg.lastUsedProcess].Stop()
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for the request to the new model to be fully handled
|
||||||
|
// and prevent race conditions see issue #277
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
pg.lastUsedProcess = modelID
|
pg.lastUsedProcess = modelID
|
||||||
|
|
||||||
|
// short circuit and exit
|
||||||
|
pg.Unlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
}
|
}
|
||||||
@@ -76,6 +89,36 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||||
|
if pg.HasMember(modelName) {
|
||||||
|
return pg.processes[modelName], true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
|
pg.Lock()
|
||||||
|
|
||||||
|
process, exists := pg.processes[modelID]
|
||||||
|
if !exists {
|
||||||
|
pg.Unlock()
|
||||||
|
return fmt.Errorf("process not found for %s", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pg.lastUsedProcess == modelID {
|
||||||
|
pg.lastUsedProcess = ""
|
||||||
|
}
|
||||||
|
pg.Unlock()
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
|
|||||||
@@ -4,21 +4,23 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
"model4": getTestSimpleResponderConfig("model4"),
|
"model4": getTestSimpleResponderConfig("model4"),
|
||||||
"model5": getTestSimpleResponderConfig("model5"),
|
"model5": getTestSimpleResponderConfig("model5"),
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Exclusive: true,
|
Exclusive: true,
|
||||||
@@ -33,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|||||||
})
|
})
|
||||||
|
|
||||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||||
assert.True(t, pg.HasMember("model5"))
|
assert.True(t, pg.HasMember("model5"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,32 +46,53 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
assert.False(t, pg.HasMember("model3"))
|
assert.False(t, pg.HasMember("model3"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping slow test")
|
||||||
|
}
|
||||||
|
|
||||||
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
// use the same listening so if a model is already running, it will fail
|
||||||
|
// this is a way to test that swap isolation is working
|
||||||
|
// properly when there are parallel requests made at the
|
||||||
|
// same time.
|
||||||
|
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||||
|
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||||
|
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||||
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||||
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(len(tests))
|
||||||
for _, modelName := range tests {
|
for _, modelName := range tests {
|
||||||
t.Run(modelName, func(t *testing.T) {
|
go func(modelName string) {
|
||||||
reqBody := `{"x", "y"}`
|
defer wg.Done()
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
}(modelName)
|
||||||
// make sure only one process is in the running state
|
|
||||||
count := 0
|
|
||||||
for _, process := range pg.processes {
|
|
||||||
if process.CurrentState() == StateReady {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.Equal(t, 1, count)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -15,6 +16,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -23,10 +26,12 @@ const (
|
|||||||
PROFILE_SPLIT_CHAR = ":"
|
PROFILE_SPLIT_CHAR = ":"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type proxyCtxKey string
|
||||||
|
|
||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
|
|
||||||
// logging
|
// logging
|
||||||
@@ -34,20 +39,54 @@ type ProxyManager struct {
|
|||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
|
metricsMonitor *metricsMonitor
|
||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// shutdown signaling
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
|
|
||||||
|
// version info
|
||||||
|
buildDate string
|
||||||
|
commit string
|
||||||
|
version string
|
||||||
|
|
||||||
|
// peer proxy see: #296, #433
|
||||||
|
peerProxy *PeerProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) *ProxyManager {
|
func New(proxyConfig config.Config) *ProxyManager {
|
||||||
// set up loggers
|
// set up loggers
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
|
||||||
|
|
||||||
if config.LogRequests {
|
var muxLogger, upstreamLogger, proxyLogger *LogMonitor
|
||||||
|
switch proxyConfig.LogToStdout {
|
||||||
|
case config.LogToStdoutNone:
|
||||||
|
muxLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
case config.LogToStdoutBoth:
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
case config.LogToStdoutUpstream:
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
proxyLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
default:
|
||||||
|
// same as config.LogToStdoutProxy
|
||||||
|
// helpful because some old tests create a config.Config directly and it
|
||||||
|
// may not have LogToStdout set explicitly
|
||||||
|
muxLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
upstreamLogger = NewLogMonitorWriter(io.Discard)
|
||||||
|
proxyLogger = NewLogMonitorWriter(muxLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxyConfig.LogRequests {
|
||||||
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
|
||||||
case "debug":
|
case "debug":
|
||||||
proxyLogger.SetLogLevel(LevelDebug)
|
proxyLogger.SetLogLevel(LevelDebug)
|
||||||
upstreamLogger.SetLogLevel(LevelDebug)
|
upstreamLogger.SetLogLevel(LevelDebug)
|
||||||
@@ -65,29 +104,123 @@ func New(config Config) *ProxyManager {
|
|||||||
upstreamLogger.SetLogLevel(LevelInfo)
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// see: https://go.dev/src/time/format.go
|
||||||
|
timeFormats := map[string]string{
|
||||||
|
"ansic": time.ANSIC,
|
||||||
|
"unixdate": time.UnixDate,
|
||||||
|
"rubydate": time.RubyDate,
|
||||||
|
"rfc822": time.RFC822,
|
||||||
|
"rfc822z": time.RFC822Z,
|
||||||
|
"rfc850": time.RFC850,
|
||||||
|
"rfc1123": time.RFC1123,
|
||||||
|
"rfc1123z": time.RFC1123Z,
|
||||||
|
"rfc3339": time.RFC3339,
|
||||||
|
"rfc3339nano": time.RFC3339Nano,
|
||||||
|
"kitchen": time.Kitchen,
|
||||||
|
"stamp": time.Stamp,
|
||||||
|
"stampmilli": time.StampMilli,
|
||||||
|
"stampmicro": time.StampMicro,
|
||||||
|
"stampnano": time.StampNano,
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
|
||||||
|
proxyLogger.SetLogTimeFormat(timeFormat)
|
||||||
|
upstreamLogger.SetLogTimeFormat(timeFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var maxMetrics int
|
||||||
|
if proxyConfig.MetricsMaxInMemory <= 0 {
|
||||||
|
maxMetrics = 1000 // Default fallback
|
||||||
|
} else {
|
||||||
|
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||||
|
}
|
||||||
|
|
||||||
|
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||||
|
if err != nil {
|
||||||
|
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||||
|
peerProxy = nil
|
||||||
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: proxyConfig,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
|
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
muxLogger: stdoutLogger,
|
muxLogger: muxLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
|
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics, proxyConfig.CaptureBuffer),
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
|
|
||||||
|
buildDate: "unknown",
|
||||||
|
commit: "abcd1234",
|
||||||
|
version: "0",
|
||||||
|
|
||||||
|
peerProxy: peerProxy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the process groups
|
// create the process groups
|
||||||
for groupID := range config.Groups {
|
for groupID := range proxyConfig.Groups {
|
||||||
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
|
||||||
pm.processGroups[groupID] = processGroup
|
pm.processGroups[groupID] = processGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.setupGinEngine()
|
pm.setupGinEngine()
|
||||||
|
|
||||||
|
// run any startup hooks
|
||||||
|
if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
// do it in the background, don't block startup -- not sure if good idea yet
|
||||||
|
go func() {
|
||||||
|
discardWriter := &DiscardWriter{}
|
||||||
|
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||||
|
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: modelID,
|
||||||
|
Success: false,
|
||||||
|
})
|
||||||
|
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
req, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: modelID,
|
||||||
|
Success: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return pm
|
return pm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) setupGinEngine() {
|
func (pm *ProxyManager) setupGinEngine() {
|
||||||
|
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
|
||||||
|
// don't log the Wake on Lan proxy health check
|
||||||
|
if c.Request.URL.Path == "/wol-health" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Start timer
|
// Start timer
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
@@ -142,59 +275,116 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
// Protected routes use pm.apiKeyAuth() middleware
|
||||||
|
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||||
|
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
// Support anthropic count_tokens API (Also added in the above PR)
|
||||||
|
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support embeddings
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
|
||||||
|
// llama-server's /reranking endpoint + aliases
|
||||||
|
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
|
// llama-server's /infill endpoint for code infilling
|
||||||
|
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
|
// llama-server's /completion endpoint
|
||||||
|
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.GET("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyGETModelHandler)
|
||||||
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||||
|
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||||
|
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||||
|
|
||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
|
||||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* User Interface Endpoints
|
||||||
|
*/
|
||||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
// Set the Content-Type header to text/html
|
c.Redirect(http.StatusFound, "/ui")
|
||||||
c.Header("Content-Type", "text/html")
|
})
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
htmlData, err := getHTMLFile("index.html")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
if err != nil {
|
})
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream)
|
||||||
return
|
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||||
}
|
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||||
_, err = c.Writer.Write(htmlData)
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
if err != nil {
|
c.String(http.StatusOK, "OK")
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
})
|
||||||
return
|
|
||||||
}
|
// see cmd/wol-proxy/wol-proxy.go, not logged
|
||||||
|
pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
})
|
})
|
||||||
|
|
||||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||||
c.Data(http.StatusOK, "image/x-icon", data)
|
c.Data(http.StatusOK, "image/x-icon", data)
|
||||||
} else {
|
} else {
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
reactFS, err := GetReactFS()
|
||||||
|
if err != nil {
|
||||||
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||||
|
} else {
|
||||||
|
// Serve files with compression support under /ui/*
|
||||||
|
// This handler checks for pre-compressed .br and .gz files
|
||||||
|
pm.ginEngine.GET("/ui/*filepath", func(c *gin.Context) {
|
||||||
|
filepath := strings.TrimPrefix(c.Param("filepath"), "/")
|
||||||
|
// Default to index.html for directory-like paths
|
||||||
|
if filepath == "" {
|
||||||
|
filepath = "index.html"
|
||||||
|
}
|
||||||
|
|
||||||
|
ServeCompressedFile(reactFS, c.Writer, c.Request, filepath)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Serve SPA for UI under /ui/* - fallback to index.html for client-side routing
|
||||||
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||||
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this looks like a file request (has extension)
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
if strings.Contains(path, ".") && !strings.HasSuffix(path, "/") {
|
||||||
|
// This was likely a file request that wasn't found
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve index.html for SPA routing
|
||||||
|
ServeCompressedFile(reactFS, c.Writer, c.Request, "index.html")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: proxymanager_api.go
|
||||||
|
// add API handler functions
|
||||||
|
addApiHandlers(pm)
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
}
|
}
|
||||||
@@ -242,18 +432,13 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
pm.shutdownCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||||
// de-alias the real model name and get a real one
|
|
||||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if !found {
|
|
||||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
processGroup := pm.findGroupByModelName(realModelName)
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
if processGroup == nil {
|
if processGroup == nil {
|
||||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if processGroup.exclusive {
|
if processGroup.exclusive {
|
||||||
@@ -265,83 +450,171 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return processGroup, realModelName, nil
|
return processGroup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := make([]gin.H, 0, len(pm.config.Models))
|
||||||
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
|
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||||
|
record := gin.H{
|
||||||
|
"id": modelId,
|
||||||
|
"object": "model",
|
||||||
|
"created": createdTime,
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
}
|
||||||
|
|
||||||
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||||
|
record["name"] = name
|
||||||
|
}
|
||||||
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||||
|
record["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata if present
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
record["meta"] = gin.H{
|
||||||
|
"llamaswap": modelConfig.Metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
if modelConfig.Unlisted {
|
if modelConfig.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, map[string]interface{}{
|
data = append(data, newRecord(id, modelConfig))
|
||||||
"id": id,
|
|
||||||
"object": "model",
|
// Include aliases
|
||||||
"created": time.Now().Unix(),
|
if pm.config.IncludeAliasesInList {
|
||||||
"owned_by": "llama-swap",
|
for _, alias := range modelConfig.Aliases {
|
||||||
})
|
if alias := strings.TrimSpace(alias); alias != "" {
|
||||||
|
data = append(data, newRecord(alias, modelConfig))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the Content-Type header to application/json
|
if pm.peerProxy != nil {
|
||||||
c.Header("Content-Type", "application/json")
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||||
|
// add peer models
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
// Skip unlisted models if not showing them
|
||||||
|
record := newRecord(modelID, config.ModelConfig{
|
||||||
|
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"peerID": peerID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
data = append(data, record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by the "id" key
|
||||||
|
sort.Slice(data, func(i, j int) bool {
|
||||||
|
si, _ := data[i]["id"].(string)
|
||||||
|
sj, _ := data[j]["id"].(string)
|
||||||
|
return si < sj
|
||||||
|
})
|
||||||
|
|
||||||
|
// Set CORS headers if origin exists
|
||||||
|
if origin := c.GetHeader("Origin"); origin != "" {
|
||||||
c.Header("Access-Control-Allow-Origin", origin)
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Use gin's JSON method which handles content-type and encoding
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
"object": "list",
|
||||||
return
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// findModelInPath searches for a valid model name in a path with slashes.
|
||||||
|
// It iteratively builds up path segments until it finds a matching model.
|
||||||
|
// Returns: (searchModelName, realModelName, remainingPath, found)
|
||||||
|
// Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
|
||||||
|
func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
|
||||||
|
parts := strings.Split(strings.TrimSpace(path), "/")
|
||||||
|
searchModelName := ""
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if searchModelName == "" {
|
||||||
|
searchModelName = part
|
||||||
|
} else {
|
||||||
|
searchModelName = searchModelName + "/" + part
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||||
|
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return "", "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
requestedModel := c.Param("model_id")
|
upstreamPath := c.Param("upstreamPath")
|
||||||
|
|
||||||
if requestedModel == "" {
|
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||||
|
|
||||||
|
if !modelFound {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||||
|
// This ensures relative URLs in upstream responses resolve correctly and
|
||||||
|
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||||
|
// HTTP method (301 would downgrade to GET).
|
||||||
|
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||||
|
newPath := "/upstream/" + searchModelName + "/"
|
||||||
|
if c.Request.URL.RawQuery != "" {
|
||||||
|
newPath += "?" + c.Request.URL.RawQuery
|
||||||
|
}
|
||||||
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||||
|
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||||
|
} else {
|
||||||
|
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
originalPath := c.Request.URL.Path
|
||||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
c.Request.URL.Path = remainingPath
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
// attempt to record metrics if it is a POST request
|
||||||
var html strings.Builder
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
// Extract keys and sort them
|
return
|
||||||
var modelIDs []string
|
}
|
||||||
for modelID, modelConfig := range pm.config.Models {
|
} else {
|
||||||
if modelConfig.Unlisted {
|
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||||
continue
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
modelIDs = append(modelIDs, modelID)
|
|
||||||
}
|
}
|
||||||
sort.Strings(modelIDs)
|
|
||||||
|
|
||||||
// Iterate over sorted keys
|
|
||||||
for _, modelID := range modelIDs {
|
|
||||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
|
||||||
}
|
|
||||||
html.WriteString("</ul></body></html>")
|
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
c.String(http.StatusOK, html.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
@@ -354,32 +627,117 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
// Look for a matching local model first
|
||||||
if err != nil {
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
useModelName := pm.config.Models[realModelName].UseModelName
|
if found {
|
||||||
if useModelName != "" {
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
|
useModelName := pm.config.Models[modelID].UseModelName
|
||||||
|
if useModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #174 strip parameters from the JSON body
|
||||||
|
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||||
|
if err != nil { // just log it and continue
|
||||||
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||||
|
} else {
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #453 set/override parameters in the JSON body
|
||||||
|
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
|
||||||
|
// issue #453 apply filters for peer requests
|
||||||
|
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||||
|
|
||||||
|
// Apply stripParams - remove specified parameters from request
|
||||||
|
stripParams := peerFilters.SanitizedStripParams()
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply setParams - set/override specified parameters in request
|
||||||
|
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||||
|
for _, key := range setParamKeys {
|
||||||
|
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
c.Request.Header.Del("transfer-encoding")
|
c.Request.Header.Del("transfer-encoding")
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
|
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
// issue #366 extract values that downstream handlers may need
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||||
return
|
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
|
||||||
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,9 +755,29 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
// Look for a matching local model first, then check peers
|
||||||
if err != nil {
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
var useModelName string
|
||||||
|
|
||||||
|
modelID, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if found {
|
||||||
|
processGroup, err := pm.swapProcessGroup(modelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
useModelName = pm.config.Models[modelID].UseModelName
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
modelID = requestedModel
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,8 +793,6 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
// If this is the model field and we have a profile, use just the model name
|
// If this is the model field and we have a profile, use just the model name
|
||||||
if key == "model" {
|
if key == "model" {
|
||||||
// # issue #69 allow custom model names to be sent to upstream
|
// # issue #69 allow custom model names to be sent to upstream
|
||||||
useModelName := pm.config.Models[realModelName].UseModelName
|
|
||||||
|
|
||||||
if useModelName != "" {
|
if useModelName != "" {
|
||||||
fieldValue = useModelName
|
fieldValue = useModelName
|
||||||
} else {
|
} else {
|
||||||
@@ -486,9 +862,46 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
if err := nextHandler(modelID, c.Writer, modifiedReq); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyGETModelHandler(c *gin.Context) {
|
||||||
|
requestedModel := c.Query("model")
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing required 'model' query parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||||
|
var modelID string
|
||||||
|
|
||||||
|
if realModelID, found := pm.config.RealModelName(requestedModel); found {
|
||||||
|
processGroup, err := pm.swapProcessGroup(realModelID)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelID = realModelID
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||||
|
nextHandler = processGroup.ProxyRequest
|
||||||
|
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||||
|
modelID = requestedModel
|
||||||
|
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||||
|
nextHandler = pm.peerProxy.ProxyRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextHandler == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable handler for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying GET Request for model %s", modelID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -503,6 +916,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||||
|
// Returns a pass-through handler if no API keys are configured.
|
||||||
|
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||||
|
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||||
|
return func(c *gin.Context) { c.Next() }
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
xApiKey := c.GetHeader("x-api-key")
|
||||||
|
|
||||||
|
var bearerKey string
|
||||||
|
var basicKey string
|
||||||
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||||
|
if strings.HasPrefix(auth, "Bearer ") {
|
||||||
|
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
} else if strings.HasPrefix(auth, "Basic ") {
|
||||||
|
// Basic Auth: base64(username:password), password is the API key
|
||||||
|
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||||
|
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||||
|
parts := strings.SplitN(string(decoded), ":", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
basicKey = parts[1] // password is the API key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use first key found: Basic, then Bearer, then x-api-key
|
||||||
|
var providedKey string
|
||||||
|
if basicKey != "" {
|
||||||
|
providedKey = basicKey
|
||||||
|
} else if bearerKey != "" {
|
||||||
|
providedKey = bearerKey
|
||||||
|
} else {
|
||||||
|
providedKey = xApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key
|
||||||
|
valid := false
|
||||||
|
for _, key := range pm.config.RequiredAPIKeys {
|
||||||
|
if providedKey == key {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||||
|
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip auth headers to prevent leakage to upstream
|
||||||
|
c.Request.Header.Del("Authorization")
|
||||||
|
c.Request.Header.Del("x-api-key")
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses(StopImmediately)
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
@@ -516,8 +990,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
for _, process := range processGroup.processes {
|
for _, process := range processGroup.processes {
|
||||||
if process.CurrentState() == StateReady {
|
if process.CurrentState() == StateReady {
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
"model": process.ID,
|
"model": process.ID,
|
||||||
"state": process.state,
|
"state": process.state,
|
||||||
|
"cmd": process.config.Cmd,
|
||||||
|
"proxy": process.config.Proxy,
|
||||||
|
"ttl": process.config.UnloadAfter,
|
||||||
|
"name": process.config.Name,
|
||||||
|
"description": process.config.Description,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -539,3 +1018,11 @@ func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) SetVersion(buildDate string, commit string, version string) {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
pm.buildDate = buildDate
|
||||||
|
pm.commit = commit
|
||||||
|
pm.version = version
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,271 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Unlisted bool `json:"unlisted"`
|
||||||
|
PeerID string `json:"peerID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
|
// Add API endpoints for React to consume
|
||||||
|
// Protected with API key authentication
|
||||||
|
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||||
|
{
|
||||||
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
|
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||||
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
|
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||||
|
apiGroup.GET("/version", pm.apiGetVersion)
|
||||||
|
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||||
|
pm.StopProcesses(StopImmediately)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) getModelStatus() []Model {
|
||||||
|
// Extract keys and sort them
|
||||||
|
models := []Model{}
|
||||||
|
|
||||||
|
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||||
|
for modelID := range pm.config.Models {
|
||||||
|
modelIDs = append(modelIDs, modelID)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIDs)
|
||||||
|
|
||||||
|
// Iterate over sorted keys
|
||||||
|
for _, modelID := range modelIDs {
|
||||||
|
// Get process state
|
||||||
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
|
state := "unknown"
|
||||||
|
if processGroup != nil {
|
||||||
|
process := processGroup.processes[modelID]
|
||||||
|
if process != nil {
|
||||||
|
var stateStr string
|
||||||
|
switch process.CurrentState() {
|
||||||
|
case StateReady:
|
||||||
|
stateStr = "ready"
|
||||||
|
case StateStarting:
|
||||||
|
stateStr = "starting"
|
||||||
|
case StateStopping:
|
||||||
|
stateStr = "stopping"
|
||||||
|
case StateShutdown:
|
||||||
|
stateStr = "shutdown"
|
||||||
|
case StateStopped:
|
||||||
|
stateStr = "stopped"
|
||||||
|
default:
|
||||||
|
stateStr = "unknown"
|
||||||
|
}
|
||||||
|
state = stateStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models = append(models, Model{
|
||||||
|
Id: modelID,
|
||||||
|
Name: pm.config.Models[modelID].Name,
|
||||||
|
Description: pm.config.Models[modelID].Description,
|
||||||
|
State: state,
|
||||||
|
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over the peer models
|
||||||
|
if pm.peerProxy != nil {
|
||||||
|
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||||
|
for _, modelID := range peer.Models {
|
||||||
|
models = append(models, Model{
|
||||||
|
Id: modelID,
|
||||||
|
PeerID: peerID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
msgTypeModelStatus messageType = "modelStatus"
|
||||||
|
msgTypeLogData messageType = "logData"
|
||||||
|
msgTypeMetrics messageType = "metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageEnvelope struct {
|
||||||
|
Type messageType `json:"type"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// sends a stream of different message types that happen on the server
|
||||||
|
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering SSE
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
sendBuffer := make(chan messageEnvelope, 25)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
sendModels := func() {
|
||||||
|
data, err := json.Marshal(pm.getModelStatus())
|
||||||
|
if err == nil {
|
||||||
|
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||||
|
select {
|
||||||
|
case sendBuffer <- msg:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendLogData := func(source string, data []byte) {
|
||||||
|
data, err := json.Marshal(gin.H{
|
||||||
|
"source": source,
|
||||||
|
"data": string(data),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMetrics := func(metrics []TokenMetrics) {
|
||||||
|
jsonData, err := json.Marshal(metrics)
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send updated models list
|
||||||
|
*/
|
||||||
|
defer event.On(func(e ProcessStateChangeEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
defer event.On(func(e ConfigFileChangedEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Log data
|
||||||
|
*/
|
||||||
|
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("proxy", data)
|
||||||
|
})()
|
||||||
|
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("upstream", data)
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Metrics data
|
||||||
|
*/
|
||||||
|
defer event.On(func(e TokenMetricsEvent) {
|
||||||
|
sendMetrics([]TokenMetrics{e.Metrics})
|
||||||
|
})()
|
||||||
|
|
||||||
|
// send initial batch of data
|
||||||
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
|
sendModels()
|
||||||
|
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case <-pm.shutdownCtx.Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case msg := <-sendBuffer:
|
||||||
|
c.SSEvent("message", msg)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||||
|
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||||
|
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, map[string]string{
|
||||||
|
"version": pm.version,
|
||||||
|
"commit": pm.commit,
|
||||||
|
"build_date": pm.buildDate,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||||
|
if capture == nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, capture)
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,20 +12,7 @@ import (
|
|||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
accept := c.GetHeader("Accept")
|
accept := c.GetHeader("Accept")
|
||||||
if strings.Contains(accept, "text/html") {
|
if strings.Contains(accept, "text/html") {
|
||||||
// Set the Content-Type header to text/html
|
c.Redirect(http.StatusFound, "/ui/")
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
|
||||||
logsHTML, err := getHTMLFile("logs.html")
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = c.Writer.Write(logsHTML)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.muxLogger.GetHistory()
|
history := pm.muxLogger.GetHistory()
|
||||||
@@ -40,17 +28,16 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering streamed logs
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
@@ -68,75 +55,53 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream new logs
|
sendChan := make(chan []byte, 10)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
defer logger.OnLogData(func(data []byte) {
|
||||||
|
select {
|
||||||
|
case sendChan <- data:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case <-c.Request.Context().Done():
|
||||||
_, err := c.Writer.Write(msg)
|
cancel()
|
||||||
if err != nil {
|
return
|
||||||
// just break the loop if we can't write for some reason
|
case <-pm.shutdownCtx.Done():
|
||||||
return
|
cancel()
|
||||||
}
|
return
|
||||||
|
case data := <-sendChan:
|
||||||
|
c.Writer.Write(data)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
|
|
||||||
// Send history first if not skipped
|
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
|
||||||
if !skipHistory {
|
|
||||||
history := logger.GetHistory()
|
|
||||||
if len(history) != 0 {
|
|
||||||
c.SSEvent("message", string(history))
|
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream new logs
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
c.SSEvent("message", string(msg))
|
|
||||||
c.Writer.Flush()
|
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||||
var logger *LogMonitor
|
switch logMonitorId {
|
||||||
|
case "":
|
||||||
if logMonitorId == "" {
|
|
||||||
// maintain the default
|
// maintain the default
|
||||||
logger = pm.muxLogger
|
return pm.muxLogger, nil
|
||||||
} else if logMonitorId == "proxy" {
|
case "proxy":
|
||||||
logger = pm.proxyLogger
|
return pm.proxyLogger, nil
|
||||||
} else if logMonitorId == "upstream" {
|
case "upstream":
|
||||||
logger = pm.upstreamLogger
|
return pm.upstreamLogger, nil
|
||||||
} else {
|
default:
|
||||||
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
// search for a models specific logger using findModelInPath
|
||||||
}
|
// to handle model names with slashes (e.g., "author/model")
|
||||||
|
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||||
|
for _, group := range pm.processGroups {
|
||||||
|
if process, found := group.GetMember(name); found {
|
||||||
|
return process.Logger(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return logger, nil
|
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||||
|
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||||
|
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||||
|
if acceptEncoding == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||||
|
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||||
|
if enc == "br" {
|
||||||
|
return "br", ".br"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||||
|
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||||
|
if enc == "gzip" {
|
||||||
|
return "gzip", ".gz"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeCompressedFile serves a file with compression support.
|
||||||
|
// It checks for pre-compressed versions and serves them with proper headers.
|
||||||
|
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||||
|
|
||||||
|
// Try to serve compressed version if client supports it
|
||||||
|
if encoding != "" {
|
||||||
|
if cf, err := fs.Open(name + ext); err == nil {
|
||||||
|
defer cf.Close()
|
||||||
|
|
||||||
|
// Verify it's a regular file (not a directory)
|
||||||
|
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||||
|
// Set the content encoding header
|
||||||
|
w.Header().Set("Content-Encoding", encoding)
|
||||||
|
w.Header().Add("Vary", "Accept-Encoding")
|
||||||
|
|
||||||
|
// Get original file info for content type detection
|
||||||
|
origFile, err := fs.Open(name)
|
||||||
|
if err == nil {
|
||||||
|
origFile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve the compressed file
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to serving the uncompressed file
|
||||||
|
file, err := fs.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.IsDir() {
|
||||||
|
http.Error(w, "is a directory", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||||
|
}
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content that should be compressed with brotli")
|
||||||
|
brContent := []byte("fake-brotli-compressed-data")
|
||||||
|
|
||||||
|
// Create a test filesystem
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that brotli is used (preferred over gzip)
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||||
|
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, brContent) {
|
||||||
|
t.Errorf("Expected brotli content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content that should be compressed with gzip")
|
||||||
|
gzContent := []byte("fake-gzip-compressed-data")
|
||||||
|
|
||||||
|
// Create a test filesystem without brotli
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, gzContent) {
|
||||||
|
t.Errorf("Expected gzip content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is uncompressed test content")
|
||||||
|
|
||||||
|
// Create a test filesystem without compressed versions
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not have Content-Encoding header since we're serving uncompressed
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, content) {
|
||||||
|
t.Errorf("Expected original content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||||
|
// Create test content
|
||||||
|
content := []byte("This is test content")
|
||||||
|
|
||||||
|
// Create a test filesystem with compressed versions
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"test.js": {Data: content, ModTime: time.Now()},
|
||||||
|
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||||
|
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||||
|
}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||||
|
// No Accept-Encoding header
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "test.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should serve uncompressed content
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(body, content) {
|
||||||
|
t.Errorf("Expected original content, got %s", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||||
|
mapFS := fstest.MapFS{}
|
||||||
|
fs := http.FS(mapFS)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectEncoding(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
acceptEncoding string
|
||||||
|
wantEncoding string
|
||||||
|
wantExt string
|
||||||
|
}{
|
||||||
|
{"br, gzip", "br", ".br"},
|
||||||
|
{"gzip, deflate", "gzip", ".gz"},
|
||||||
|
{"gzip", "gzip", ".gz"},
|
||||||
|
{"br", "br", ".br"},
|
||||||
|
{"", "", ""},
|
||||||
|
{"deflate", "", ""},
|
||||||
|
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||||
|
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||||
|
{"browser", "", ""},
|
||||||
|
{"compress, deflate", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||||
|
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||||
|
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||||
|
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with actual pre-compressed files from ui_dist
|
||||||
|
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||||
|
// Check if ui_dist exists
|
||||||
|
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||||
|
t.Skip("ui_dist not found, skipping real file test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a .js or .css file that has compressed versions
|
||||||
|
entries, err := os.ReadDir("./ui_dist/assets")
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var testFile string
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := entry.Name()
|
||||||
|
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||||
|
// Check if compressed versions exist
|
||||||
|
base := strings.TrimSuffix(name, ".js")
|
||||||
|
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||||
|
testFile = "assets/" + name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if testFile == "" {
|
||||||
|
t.Skip("No suitable test file found with compressed versions")
|
||||||
|
}
|
||||||
|
|
||||||
|
fs := http.FS(os.DirFS("./ui_dist"))
|
||||||
|
|
||||||
|
// Test brotli
|
||||||
|
t.Run("brotli", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "br")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, testFile)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test gzip
|
||||||
|
t.Run("gzip", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||||
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ServeCompressedFile(fs, w, req, testFile)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||||
|
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid gzip
|
||||||
|
reader, err := gzip.NewReader(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected valid gzip content: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Just read to verify it's valid
|
||||||
|
_, err = io.Copy(io.Discard, reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decompress gzip: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed ui_dist
|
||||||
|
var reactStaticFS embed.FS
|
||||||
|
|
||||||
|
// GetReactFS returns the embedded React filesystem
|
||||||
|
func GetReactFS() (http.FileSystem, error) {
|
||||||
|
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return http.FS(subFS), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReactIndexHTML returns the main index.html for the React app
|
||||||
|
func GetReactIndexHTML() ([]byte, error) {
|
||||||
|
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||||
|
}
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script installs llama-swap on Linux.
|
||||||
|
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v "$1" >/dev/null; }
|
||||||
|
require() {
|
||||||
|
_MISSING=''
|
||||||
|
for TOOL in "$@"; do
|
||||||
|
if ! available "$TOOL"; then
|
||||||
|
_MISSING="$_MISSING $TOOL"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "$_MISSING"
|
||||||
|
}
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
NEEDS=$(require tee tar python3 mktemp)
|
||||||
|
if [ -n "$NEEDS" ]; then
|
||||||
|
status "ERROR: The following tools are required but missing:"
|
||||||
|
for NEED in $NEEDS; do
|
||||||
|
echo " - $NEED"
|
||||||
|
done
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) error "Unsupported architecture: $ARCH" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
IS_WSL2=false
|
||||||
|
|
||||||
|
KERN=$(uname -r)
|
||||||
|
case "$KERN" in
|
||||||
|
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||||
|
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||||
|
*) ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
download_binary() {
|
||||||
|
ASSET_NAME="linux_$ARCH"
|
||||||
|
|
||||||
|
TMPDIR=$(mktemp -d)
|
||||||
|
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||||
|
PYTHON_SCRIPT=$(cat <<EOF
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
ASSET_NAME = "${ASSET_NAME}"
|
||||||
|
|
||||||
|
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||||
|
data = json.load(resp)
|
||||||
|
for asset in data.get("assets", []):
|
||||||
|
if ASSET_NAME in asset.get("name", ""):
|
||||||
|
url = asset["browser_download_url"]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("Downloading:", url, file=sys.stderr)
|
||||||
|
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||||
|
urllib.request.urlretrieve(url, output_path)
|
||||||
|
print(output_path)
|
||||||
|
EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||||
|
if [ ! -f "$TARFILE" ]; then
|
||||||
|
error "Failed to download binary."
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Extracting to /usr/local/bin"
|
||||||
|
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||||
|
}
|
||||||
|
download_binary
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
if ! id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Creating llama-swap user..."
|
||||||
|
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||||
|
fi
|
||||||
|
if getent group render >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to render group..."
|
||||||
|
$SUDO usermod -a -G render llama-swap
|
||||||
|
fi
|
||||||
|
if getent group video >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to video group..."
|
||||||
|
$SUDO usermod -a -G video llama-swap
|
||||||
|
fi
|
||||||
|
if getent group docker >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to docker group..."
|
||||||
|
$SUDO usermod -a -G docker llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Adding current user to llama-swap group..."
|
||||||
|
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||||
|
|
||||||
|
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
status "Creating default config.yaml..."
|
||||||
|
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||||
|
# default 15s likely to fail for default models due to downloading models
|
||||||
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name qwen2.5
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop qwen2.5
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name smollm2
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop smollm2
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Creating llama-swap systemd service..."
|
||||||
|
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||||
|
[Unit]
|
||||||
|
Description=llama-swap
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
User=llama-swap
|
||||||
|
Group=llama-swap
|
||||||
|
|
||||||
|
# set this to match your environment
|
||||||
|
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=3
|
||||||
|
StartLimitBurst=3
|
||||||
|
StartLimitInterval=30
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
EOF
|
||||||
|
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||||
|
case $SYSTEMCTL_RUNNING in
|
||||||
|
running|degraded)
|
||||||
|
status "Enabling and starting llama-swap service..."
|
||||||
|
$SUDO systemctl daemon-reload
|
||||||
|
$SUDO systemctl enable llama-swap
|
||||||
|
|
||||||
|
start_service() { $SUDO systemctl restart llama-swap; }
|
||||||
|
trap start_service EXIT
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
warning "systemd is not running"
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success() {
|
||||||
|
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||||
|
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||||
|
status 'Install complete.'
|
||||||
|
}
|
||||||
|
|
||||||
|
# WSL2 only supports GPUs via nvidia passthrough
|
||||||
|
# so check for nvidia-smi to determine if GPU is available
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||||
|
status "Nvidia GPU detected."
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script uninstalls llama-swap on Linux.
|
||||||
|
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v $1 >/dev/null; }
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
status "Stopping llama-swap service..."
|
||||||
|
$SUDO systemctl stop llama-swap
|
||||||
|
|
||||||
|
status "Disabling llama-swap service..."
|
||||||
|
$SUDO systemctl disable llama-swap
|
||||||
|
}
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
if available llama-swap; then
|
||||||
|
status "Removing llama-swap binary..."
|
||||||
|
$SUDO rm $(which llama-swap)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
while true; do
|
||||||
|
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||||
|
read answer
|
||||||
|
case "$answer" in
|
||||||
|
[Yy]* )
|
||||||
|
$SUDO rm -r /usr/share/llama-swap
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
[Nn]* | "" )
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
* )
|
||||||
|
echo "Invalid input. Please enter y or n."
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap user..."
|
||||||
|
$SUDO userdel llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
if getent group llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap group..."
|
||||||
|
$SUDO groupdel llama-swap
|
||||||
|
fi
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
node_modules
|
||||||
|
.vite
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||||
|
<link rel="shortcut icon" href="/favicon.ico" />
|
||||||
|
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||||
|
<link rel="manifest" href="/site.webmanifest" />
|
||||||
|
<title>llama-swap</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="app"></div>
|
||||||
|
<script type="module" src="/src/main.ts"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||