Compare commits
297 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 64e4c79fc3 | |||
| 19fb5f35e9 | |||
| b45102bde8 | |||
| 1688bdd1e9 | |||
| d33d51fa75 | |||
| e3bf065574 | |||
| 3e52144058 | |||
| d5e52d7d00 | |||
| 17e5263a76 | |||
| 8d6d949ec3 | |||
| b5fde8eb6d | |||
| 7eef5defb8 | |||
| bc01e6f539 | |||
| 0462e3dc3f | |||
| 7b20fc011b | |||
| 20738f3623 | |||
| cdea7d16bd | |||
| 5de387dbf9 | |||
| 6f8e7ccb57 | |||
| 4384315b44 | |||
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 | |||
| 565c44766d | |||
| e6a9e210ba | |||
| d3f329f924 | |||
| 98879b38c1 | |||
| 7b3b0f5eae | |||
| 021ccceef1 | |||
| f03871c50a | |||
| dc00d17abe | |||
| dea98733c3 | |||
| bccce5fa19 | |||
| c968da1b73 | |||
| a883d68d4f | |||
| b1dec8b735 | |||
| 06523d8c1e | |||
| 86e9b93c37 | |||
| 3acace810f | |||
| 554d29e87d | |||
| 3567b7df08 | |||
| 38738525c9 | |||
| c0fc858193 | |||
| b429349e8a | |||
| eab2efd7b5 | |||
| 6aedbe121a | |||
| b24467ab89 | |||
| 12b69fb718 | |||
| f91a8b2462 | |||
| a89b803d4a | |||
| f852689104 | |||
| e250e71e59 | |||
| d18dc26d01 | |||
| 8357714421 | |||
| c07179d6e2 | |||
| 7ff50631e0 | |||
| 9fc0431531 | |||
| 6516532568 | |||
| d58a8b85bf | |||
| caf9e98b1e | |||
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 | |||
| fc3bb716df | |||
| c36986fef6 | |||
| 558801db1a | |||
| b21dee27c1 | |||
| f58c8c8ec5 | |||
| 954e2dee73 | |||
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 | |||
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 | |||
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac | |||
| d625ab8d92 | |||
| a3f82c140b | |||
| 5c97299e7b | |||
| 671c1a5a7b | |||
| 52c0196e0f | |||
| 3201a68a04 | |||
| 3ac94ad20e | |||
| 60355bf74a | |||
| 9b2ed244e2 | |||
| eeb72297f7 | |||
| eabfe70cc6 | |||
| 29cd98878d | |||
| b3d331da0d | |||
| 62275e078d | |||
| 88916059e1 | |||
| 082d5d0fc5 | |||
| 53338938bd | |||
| af653347ae | |||
| 1e25b44a06 | |||
| 0815bb4cc3 | |||
| 7187cfe52e | |||
| 24089d2d9c | |||
| ebabe55ff3 | |||
| 41a338297c | |||
| 7e3353efeb | |||
| 4ed58fb173 | |||
| f5a2be698d | |||
| f5e6ec3b7a | |||
| 3f462da146 | |||
| 48bd766536 | |||
| 8d319da4dd | |||
| be7c502448 | |||
| 92336f00bf | |||
| ed2a50d9a6 | |||
| 0acfdb9f78 | |||
| 96a8ea0241 | |||
| f20f2c9b7a | |||
| 7a97c38828 | |||
| 4885132565 | |||
| 8b46a0b7f1 | |||
| 1b6736ec6f | |||
| ddc1ce031e | |||
| 11d024bbaa | |||
| 43e23c16dc | |||
| f9c8e763ba | |||
| d7e1bb9f7c | |||
| ab93460a8b | |||
| 13d4552edc | |||
| 6667e307a2 | |||
| 7ac446e6a9 | |||
| eab9795bcc | |||
| 09bdd86b54 | |||
| 85cd74a51c | |||
| 314d2f2212 | |||
| fad25f3e11 | |||
| 2c3e3e27f7 | |||
| baeb0c4e7f | |||
| 2833517eef | |||
| abdc2bfdb3 | |||
| c3b834737f | |||
| 3c8e727b73 |
@@ -0,0 +1,22 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
@@ -0,0 +1,39 @@
|
||||
---
|
||||
name: Bug Report
|
||||
about: I found a defect
|
||||
title: ''
|
||||
labels: 'unconfirmed bug'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
> [!IMPORTANT]
|
||||
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Expected behaviour**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Operating system and version**
|
||||
|
||||
- OS: (linux, osx, windows, freebsd, etc)
|
||||
- GPUs: (list architecture)
|
||||
|
||||
**My Configuration**
|
||||
|
||||
```yaml
|
||||
# copy / paste your configuration here
|
||||
```
|
||||
|
||||
**Proxy Logs**
|
||||
|
||||
```
|
||||
# copy / paste from /logs
|
||||
```
|
||||
|
||||
**Upstream Logs**
|
||||
|
||||
```
|
||||
# copy/paste from /logs
|
||||
```
|
||||
@@ -0,0 +1,23 @@
|
||||
# https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "32 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,41 @@
|
||||
name: Validate JSON Schema
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "config-schema.json"
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
validate-schema:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Validate JSON Schema
|
||||
run: |
|
||||
# Check if the file is valid JSON
|
||||
if ! jq empty config-schema.json 2>/dev/null; then
|
||||
echo "Error: config-schema.json is not valid JSON"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate that it's a valid JSON Schema
|
||||
# Check for required $schema field
|
||||
if ! jq -e '."$schema"' config-schema.json > /dev/null; then
|
||||
echo "Warning: config-schema.json should have a \$schema field"
|
||||
fi
|
||||
|
||||
# Check that it has either properties or definitions
|
||||
if ! jq -e '.properties or .definitions or ."$defs"' config-schema.json > /dev/null; then
|
||||
echo "Warning: JSON Schema should contain properties, definitions, or \$defs"
|
||||
fi
|
||||
|
||||
echo "✓ config-schema.json is valid"
|
||||
@@ -0,0 +1,73 @@
|
||||
name: Build Containers
|
||||
|
||||
on:
|
||||
# time has no specific meaning, trying to time it after
|
||||
# the llama.cpp daily packages are published
|
||||
# https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml
|
||||
schedule:
|
||||
- cron: "37 5 * * *"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
- 'docker/*.Containerfile'
|
||||
|
||||
# grant permissions on GITHUB_TOKEN to publish packages
|
||||
# ref: https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#publishing-a-package-using-an-action
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
delete-untagged-containers:
|
||||
needs: build-and-push
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/delete-package-versions@v5
|
||||
with:
|
||||
package-name: 'llama-swap'
|
||||
package-type: 'container'
|
||||
delete-only-untagged-versions: 'true'
|
||||
@@ -0,0 +1,66 @@
|
||||
name: Windows CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci-windows.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: make simple-responder-windows
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('cmd/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
shell: bash
|
||||
run: make test-all
|
||||
@@ -0,0 +1,70 @@
|
||||
name: Linux CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
# only run when backend source changes
|
||||
# cmd/ is excluded because it contains utilities without tests
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- '**/*.go'
|
||||
- '!cmd/**'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Makefile'
|
||||
- '.github/workflows/go-ci.yml'
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
# Only run in this linux based runner
|
||||
- name: Check Formatting
|
||||
run: |
|
||||
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||
gofmt -l . | grep -v 'event/.*_test.go'
|
||||
exit 1
|
||||
fi
|
||||
# cache simple-responder to save the build time
|
||||
- name: Restore Simple Responder
|
||||
id: restore-simple-responder
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
# necessary for testing proxy/Process swapping
|
||||
- name: Create simple-responder
|
||||
run: make simple-responder
|
||||
|
||||
- name: Save Simple Responder
|
||||
# nothing new to save ... skip this step
|
||||
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||
id: save-simple-responder
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: ./build
|
||||
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||
|
||||
- name: Test all
|
||||
run: make test-all
|
||||
@@ -3,7 +3,14 @@ name: goreleaser
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag version to release (e.g. v144)"
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -12,22 +19,56 @@ jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
-
|
||||
name: Set up Go
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "24"
|
||||
- name: Install dependencies and build UI
|
||||
run: |
|
||||
cd ui-svelte
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
# either 'goreleaser' (default) or 'goreleaser-pro'
|
||||
distribution: goreleaser
|
||||
# 'latest', 'nightly', or a semver
|
||||
version: '~> v2'
|
||||
version: "~> v2"
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
trigger-tap-update:
|
||||
runs-on: ubuntu-latest
|
||||
needs: goreleaser
|
||||
steps:
|
||||
- name: "Resolve tag to dispatch"
|
||||
id: tag
|
||||
run: |
|
||||
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: "Trigger tap repository update"
|
||||
uses: peter-evans/repository-dispatch@v2
|
||||
with:
|
||||
token: ${{ secrets.TAP_REPO_PAT }}
|
||||
repository: mostlygeek/homebrew-llama-swap
|
||||
event-type: new-release
|
||||
client-payload: |
|
||||
{
|
||||
"release": {
|
||||
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
name: UI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
paths:
|
||||
- 'ui-svelte/**'
|
||||
- '.github/workflows/ui-tests.yml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
run-tests:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ui-svelte
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '24'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ui-svelte/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Type check
|
||||
run: npm run check
|
||||
|
||||
- name: Run tests
|
||||
run: npm test
|
||||
@@ -4,3 +4,4 @@ build/
|
||||
dist/
|
||||
.vscode
|
||||
.DS_Store
|
||||
.dev/
|
||||
|
||||
+22
-1
@@ -6,6 +6,27 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- freebsd
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm64
|
||||
ignore:
|
||||
- goos: freebsd
|
||||
goarch: arm64
|
||||
- goos: windows
|
||||
goarch: arm64
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
formats:
|
||||
- tar.gz
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
builds_info:
|
||||
group: root
|
||||
owner: root
|
||||
format_overrides:
|
||||
# use zip format for windows
|
||||
- goos: windows
|
||||
formats:
|
||||
- zip
|
||||
@@ -0,0 +1,50 @@
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
fixes #123
|
||||
```
|
||||
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
@@ -19,27 +19,54 @@ all: mac linux simple-responder
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
|
||||
test:
|
||||
go test -short -v ./proxy
|
||||
proxy/ui_dist/placeholder.txt:
|
||||
mkdir -p proxy/ui_dist
|
||||
touch $@
|
||||
|
||||
test-all:
|
||||
go test -v ./proxy
|
||||
# use cached test results while developing
|
||||
test-dev: proxy/ui_dist/placeholder.txt
|
||||
go test -short ./proxy/...
|
||||
staticcheck ./proxy/... || true
|
||||
|
||||
test: proxy/ui_dist/placeholder.txt
|
||||
go test -short -count=1 ./proxy/...
|
||||
|
||||
# for CI - full test (takes longer)
|
||||
test-all: proxy/ui_dist/placeholder.txt
|
||||
go test -race -count=1 ./proxy/...
|
||||
|
||||
ui/node_modules:
|
||||
cd ui-svelte && npm install
|
||||
|
||||
# build react UI
|
||||
ui: ui/node_modules
|
||||
cd ui-svelte && npm run build
|
||||
|
||||
# Build OSX binary
|
||||
mac:
|
||||
mac: ui
|
||||
@echo "Building Mac binary..."
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||
|
||||
# Build Linux binary
|
||||
linux:
|
||||
linux: ui
|
||||
@echo "Building Linux binary..."
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||
|
||||
# Build Windows binary
|
||||
windows: ui
|
||||
@echo "Building Windows binary..."
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||
|
||||
# for testing proxy.Process
|
||||
simple-responder:
|
||||
@echo "Building simple responder"
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 cmd/simple-responder/simple-responder.go
|
||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 cmd/simple-responder/simple-responder.go
|
||||
|
||||
simple-responder-windows:
|
||||
@echo "Building simple responder for windows"
|
||||
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe cmd/simple-responder/simple-responder.go
|
||||
|
||||
# Ensure build directory exists
|
||||
$(BUILD_DIR):
|
||||
@@ -59,5 +86,11 @@ release:
|
||||
echo "tagging new version: $$new_tag"; \
|
||||
git tag "$$new_tag";
|
||||
|
||||
GOOS ?= $(shell go env GOOS 2>/dev/null || echo linux)
|
||||
GOARCH ?= $(shell go env GOARCH 2>/dev/null || echo amd64)
|
||||
wol-proxy: $(BUILD_DIR)
|
||||
@echo "Building wol-proxy"
|
||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||
|
||||
# Phony targets
|
||||
.PHONY: all clean osx linux
|
||||
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
||||
|
||||
@@ -1,165 +1,247 @@
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
# llama-swap
|
||||
|
||||

|
||||
Run multiple LLM models on your machine and hot-swap between them as needed. llama-swap works with any OpenAI API-compatible server, giving you the flexibility to switch models without restarting your applications.
|
||||
|
||||
# Introduction
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
|
||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). Download a pre-built [release](https://github.com/mostlygeek/llama-swap/releases) or built it yourself from source with `make clean all`.
|
||||
|
||||
## How does it work?
|
||||
When a request is made to an OpenAI compatible endpoints, lama-swap will extract the `model` value load the appropriate server configuration to serve it. If a server is already running it will stop it and start a new one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||
Built in Go for performance and simplicity, llama-swap has zero dependencies and is incredibly easy to set up. Get started in minutes - just one binary and one configuration file.
|
||||
|
||||
## Features:
|
||||
|
||||
- ✅ Easy to deploy: single binary with no dependencies
|
||||
- ✅ Easy to config: single yaml file
|
||||
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
|
||||
- ✅ On-demand model switching
|
||||
- ✅ Full control over server settings per model
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, stable-diffusion.cpp, etc.)
|
||||
- future proof, upgrade your inference servers at any time.
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/audio/speech`
|
||||
- ✅ Multiple GPU support
|
||||
- ✅ Run multiple models at once with `profiles`
|
||||
- ✅ Remote log monitoring at `/log`
|
||||
- ✅ Automatic unloading of models from GPUs after timeout
|
||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
- `/completion` - for completion endpoint
|
||||
- ✅ llama-swap API
|
||||
- `/ui` - web UI
|
||||
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||
|
||||
## config.yaml
|
||||
### Web UI
|
||||
|
||||
llama-swap's configuration is purposefully simple.
|
||||
llama-swap includes a real time web interface for monitoring logs and controlling models:
|
||||
|
||||
```yaml
|
||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||
# Default (and minimum) is 15 seconds
|
||||
healthCheckTimeout: 60
|
||||
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
|
||||
|
||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
||||
logRequests: true
|
||||
The Activity Page shows recent requests:
|
||||
|
||||
# define valid model values and the upstream server start
|
||||
models:
|
||||
"llama":
|
||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||
|
||||
# where to reach the server started by cmd, make sure the ports match
|
||||
proxy: http://127.0.0.1:8999
|
||||
## Installation
|
||||
|
||||
# aliases names to use this model for
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
llama-swap can be installed in multiple ways
|
||||
|
||||
# check this path for an HTTP 200 OK before serving requests
|
||||
# default: /health to match llama.cpp
|
||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||
# until the model is ready
|
||||
checkEndpoint: /custom-endpoint
|
||||
1. Docker
|
||||
2. Homebrew (OSX and Linux)
|
||||
3. WinGet
|
||||
4. From release binaries
|
||||
5. From source
|
||||
|
||||
# automatically unload the model after this many seconds
|
||||
# ttl values must be a value greater than 0
|
||||
# default: 0 = never unload model
|
||||
ttl: 60
|
||||
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||
|
||||
"qwen":
|
||||
# environment variables to pass to the command
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
|
||||
The stable-diffusion.cpp server is also included for the musa and vulkan platforms.
|
||||
|
||||
# multiline for readability
|
||||
cmd: >
|
||||
llama-server --port 8999
|
||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
```shell
|
||||
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# unlisted models do not show up in /v1/models or /upstream lists
|
||||
# but they can still be requested as normal
|
||||
"qwen-unlisted":
|
||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
unlisted: true
|
||||
# run with a custom configuration and models directory
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda
|
||||
|
||||
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||
#
|
||||
# Tips:
|
||||
# - each model must be listening on a unique address and port
|
||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
||||
# - the profile will load and unload all models in the profile at the same time
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
# configuration hot reload supported with a
|
||||
# directory volume mount
|
||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||
-v /path/to/config:/config \
|
||||
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
|
||||
```
|
||||
|
||||
### Advanced Examples
|
||||
<details>
|
||||
<summary>
|
||||
more examples
|
||||
</summary>
|
||||
|
||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
||||
```shell
|
||||
# pull latest images per platform
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cpu
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:vulkan
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:intel
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:musa
|
||||
|
||||
### Installation
|
||||
# tagged llama-swap, platform and llama-server version images
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
|
||||
|
||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||
* _Note: Windows currently untested._
|
||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
||||
# non-root cuda
|
||||
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Homebrew Install (macOS/Linux)
|
||||
|
||||
```shell
|
||||
brew tap mostlygeek/llama-swap
|
||||
brew install llama-swap
|
||||
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||
```
|
||||
|
||||
### WinGet Install (Windows)
|
||||
|
||||
> [!NOTE]
|
||||
> WinGet is maintained by community contributor [Dvd-Znf](https://github.com/Dvd-Znf) ([#327](https://github.com/mostlygeek/llama-swap/issues/327)). It is not an official part of llama-swap.
|
||||
|
||||
```shell
|
||||
# install
|
||||
C:\> winget install llama-swap
|
||||
|
||||
# upgrade
|
||||
C:\> winget upgrade llama-swap
|
||||
```
|
||||
|
||||
### Pre-built Binaries
|
||||
|
||||
Binaries are available on the [release](https://github.com/mostlygeek/llama-swap/releases) page for Linux, Mac, Windows and FreeBSD.
|
||||
|
||||
### Building from source
|
||||
|
||||
1. Install golang for your system
|
||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
||||
1. Building requires Go and Node.js (for UI).
|
||||
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||
1. `make clean all`
|
||||
1. Binaries will be in `build/` subdirectory
|
||||
1. look in the `build/` subdirectory for the llama-swap binary
|
||||
|
||||
## Monitoring Logs
|
||||
## Configuration
|
||||
|
||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
||||
|
||||
Of course, CLI access is also supported:
|
||||
```yaml
|
||||
# minimum viable config.yaml
|
||||
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
# sends up to the last 10KB of logs
|
||||
curl http://host/logs'
|
||||
|
||||
# streams logs
|
||||
curl -Ns 'http://host/logs/stream'
|
||||
That's all you need to get started:
|
||||
|
||||
1. `models` - holds all model configurations
|
||||
2. `model1` - the ID used in API calls
|
||||
3. `cmd` - the command to run to start the server.
|
||||
4. `${PORT}` - an automatically assigned port number
|
||||
|
||||
Almost all configuration settings are optional and can be added one step at a time:
|
||||
|
||||
- Advanced features
|
||||
- `groups` to run multiple models at once
|
||||
- `hooks` to run things on startup
|
||||
- `macros` reusable snippets
|
||||
- Model customization
|
||||
- `ttl` to automatically unload models
|
||||
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||
- `env` to pass custom environment variables to inference servers
|
||||
- `cmdStop` gracefully stop Docker/Podman containers
|
||||
- `useModelName` to override model names sent to upstream servers
|
||||
- `${PORT}` automatic port variables for dynamic port assignment
|
||||
- `filters` rewrite parts of requests before sending to the upstream server
|
||||
|
||||
See the [configuration documentation](docs/configuration.md) for all options.
|
||||
|
||||
## How does llama-swap work?
|
||||
|
||||
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to handle the request correctly.
|
||||
|
||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||
|
||||
## Reverse Proxy Configuration (nginx)
|
||||
|
||||
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks Server‑Sent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
|
||||
|
||||
Recommended nginx configuration snippets:
|
||||
|
||||
```nginx
|
||||
# SSE for UI events/logs
|
||||
location /api/events {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Streaming chat completions (stream=true)
|
||||
location /v1/chat/completions {
|
||||
proxy_pass http://your-llama-swap-backend;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
```
|
||||
|
||||
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
|
||||
|
||||
## Monitoring Logs on the CLI
|
||||
|
||||
```sh
|
||||
# sends up to the last 10KB of logs
|
||||
$ curl http://host/logs
|
||||
|
||||
# streams combined logs
|
||||
curl -Ns http://host/logs/stream
|
||||
|
||||
# stream llama-swap's proxy status logs
|
||||
curl -Ns http://host/logs/stream/proxy
|
||||
|
||||
# stream logs from upstream processes that llama-swap loads
|
||||
curl -Ns http://host/logs/stream/upstream
|
||||
|
||||
# stream logs only from a specific model
|
||||
curl -Ns http://host/logs/stream/{model_id}
|
||||
|
||||
# stream and filter logs with linux pipes
|
||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||
|
||||
# skips history and just streams new log entries
|
||||
# appending ?no-history will disable sending buffered history first
|
||||
curl -Ns 'http://host/logs/stream?no-history'
|
||||
```
|
||||
|
||||
## Systemd Unit Files
|
||||
## Do I need to use llama.cpp's server (llama-server)?
|
||||
|
||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
||||
Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported.
|
||||
|
||||
`/etc/systemd/system/llama-swap.service`
|
||||
```
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals for proper shutdown.
|
||||
|
||||
[Service]
|
||||
User=nobody
|
||||
## Star History
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
||||
> [!NOTE]
|
||||
> ⭐️ Star this project to help others discover it!
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Replace ring.Ring with Efficient Circular Byte Buffer
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
|
||||
|
||||
## Current Issues
|
||||
|
||||
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
|
||||
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
|
||||
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
|
||||
4. Linked list structure has poor cache locality and pointer overhead
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### New CircularBuffer Type
|
||||
|
||||
Create a simple circular byte buffer with:
|
||||
- Single pre-allocated `[]byte` of fixed capacity (10KB)
|
||||
- `head` and `size` integers to track write position and data length
|
||||
- No per-write allocations
|
||||
|
||||
### API Requirements
|
||||
|
||||
The new buffer must support:
|
||||
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
|
||||
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
|
||||
|
||||
### Implementation Details
|
||||
|
||||
```go
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
```
|
||||
|
||||
**Write logic:**
|
||||
- If `len(p) >= capacity`: just keep the last `capacity` bytes
|
||||
- Otherwise: write bytes at `head`, wrapping around if needed
|
||||
- Update `head` and `size` accordingly
|
||||
- Data is copied into the internal buffer (not stored by reference)
|
||||
|
||||
**GetHistory logic:**
|
||||
- Calculate start position: `(head - size + cap) % cap`
|
||||
- If not wrapped: single slice copy
|
||||
- If wrapped: two copies (end of buffer + beginning)
|
||||
- Returns a **new slice** (copy), not a view into internal buffer
|
||||
|
||||
### Immutability Guarantees (must preserve)
|
||||
|
||||
Per existing tests:
|
||||
1. Modifying input `[]byte` after `Write()` must not affect stored data
|
||||
2. `GetHistory()` returns independent copy - modifications don't affect buffer
|
||||
|
||||
## Files to Modify
|
||||
|
||||
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
|
||||
|
||||
## Testing Plan
|
||||
|
||||
Existing tests in `logMonitor_test.go` should continue to pass:
|
||||
- `TestLogMonitor` - Basic write/read and subscriber notification
|
||||
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
|
||||
- `TestWrite_LogTimeFormat` - Timestamp formatting
|
||||
|
||||
Add new tests:
|
||||
- Test buffer wrap-around behavior
|
||||
- Test large writes that exceed buffer capacity
|
||||
- Test exact capacity boundary conditions
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Create `circularBuffer` struct in `logMonitor.go`
|
||||
- [ ] Implement `Write()` method for circular buffer
|
||||
- [ ] Implement `GetHistory()` method for circular buffer
|
||||
- [ ] Update `LogMonitor` struct to use new buffer
|
||||
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
|
||||
- [ ] Update `LogMonitor.Write()` to use new buffer
|
||||
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
|
||||
- [ ] Remove `"container/ring"` import
|
||||
- [ ] Run `make test-dev` to verify existing tests pass
|
||||
- [ ] Add wrap-around test case
|
||||
- [ ] Run `make test-all` for final validation
|
||||
@@ -0,0 +1,292 @@
|
||||
# Add Model Metadata Support with Typed Macros
|
||||
|
||||
## Overview
|
||||
|
||||
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||
|
||||
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. Enhanced Macro System
|
||||
|
||||
**Current State:**
|
||||
|
||||
- Macros are defined as `map[string]string` at both global and model levels
|
||||
- Only string substitution is supported
|
||||
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Implement type-preserving macro substitution:
|
||||
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||
- Add validation to ensure macro values are scalar types only
|
||||
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||
|
||||
**Implementation Details:**
|
||||
|
||||
- Create a generic helper function to perform macro substitution that:
|
||||
- Takes a value of type `any`
|
||||
- Recursively processes maps, slices, and scalar values
|
||||
- Replaces `${macro_name}` patterns with macro values
|
||||
- Preserves types for direct substitution
|
||||
- Converts to strings for interpolated substitution
|
||||
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||
- Maintain backward compatibility with existing string-only macros
|
||||
|
||||
### 2. Metadata Field in ModelConfig
|
||||
|
||||
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||
- Apply macro substitution to metadata values during config loading
|
||||
|
||||
**Schema Requirements:**
|
||||
|
||||
- Metadata is optional (default: empty/nil map)
|
||||
- Supports nested structures (objects within objects, arrays, etc.)
|
||||
- All string values within metadata undergo macro substitution
|
||||
- Type preservation rules apply as described above
|
||||
|
||||
### 3. Macro Substitution in Metadata
|
||||
|
||||
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||
|
||||
**Process Flow:**
|
||||
|
||||
1. After loading YAML configuration
|
||||
2. After model-level and global macro merging
|
||||
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||
5. Process recursively through all nested structures
|
||||
|
||||
**Substitution Rules:**
|
||||
|
||||
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||
- `temperature: ${temp}` → keeps float type from temp macro
|
||||
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||
- Arrays and nested objects are processed recursively
|
||||
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||
|
||||
### 4. API Response Updates
|
||||
|
||||
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||
|
||||
**Current Behavior:**
|
||||
|
||||
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||
- Optionally includes: `name`, `description`
|
||||
|
||||
**Required Changes:**
|
||||
|
||||
- Add metadata to each model record under the key `llamaswap_meta`
|
||||
- Only include `llamaswap_meta` if metadata is non-empty
|
||||
- Preserve all types when marshaling to JSON
|
||||
- Maintain existing sorting by model ID
|
||||
|
||||
**Example Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "llama",
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "llama-swap",
|
||||
"name": "llama 3.1 8B",
|
||||
"description": "A small but capable model",
|
||||
"llamaswap_meta": {
|
||||
"port": 10001,
|
||||
"temperature": 0.7,
|
||||
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||
"an_obj": {
|
||||
"a": "1",
|
||||
"b": 2,
|
||||
"c": [0.7, false, "model: llama"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Validation and Error Handling
|
||||
|
||||
**Macro Validation:**
|
||||
|
||||
- Extend `validateMacro()` to accept values of type `any`
|
||||
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||
- Reject complex types (maps, slices, structs) as macro values
|
||||
- Maintain existing validation for macro names and lengths
|
||||
|
||||
**Configuration Loading:**
|
||||
|
||||
- Fail fast if unknown macros are found in metadata
|
||||
- Provide clear error messages indicating which model and field contains errors
|
||||
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### Test 1: Model-Level Macros with Different Types
|
||||
|
||||
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define model with macros of each scalar type
|
||||
- Verify metadata correctly substitutes and preserves types
|
||||
- Test direct substitution (`port: ${PORT}`)
|
||||
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||
- Verify nested objects and arrays work correctly
|
||||
|
||||
### Test 2: Global and Model Macro Precedence
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Define same macro at global and model level with different types
|
||||
- Verify model-level macro takes precedence
|
||||
- Test metadata uses correct macro value
|
||||
- Verify type is preserved from the winning macro
|
||||
|
||||
### Test 3: Macro Validation
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Test that complex types (maps, arrays) are rejected as macro values
|
||||
- Verify error message includes: macro name and type that was rejected
|
||||
- Test that scalar types (string, int, float, bool) are accepted
|
||||
- Each type should load without error
|
||||
- Test macro name validation still works with `any` types
|
||||
- Invalid characters, reserved names, length limits should still be enforced
|
||||
|
||||
### Test 4: Metadata in API Response
|
||||
|
||||
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
|
||||
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Model with metadata → verify `llamaswap_meta` key appears
|
||||
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||
- Verify all types are correctly marshaled to JSON
|
||||
- Verify nested structures are preserved
|
||||
- Verify macro substitution has occurred before serialization
|
||||
|
||||
### Test 5: Unknown Macros in Metadata
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Use undefined macro in metadata
|
||||
- Verify configuration loading fails with clear error
|
||||
- Error should indicate model name and that macro is undefined
|
||||
|
||||
### Test 6: Recursive Substitution
|
||||
|
||||
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
- Metadata with deeply nested structures
|
||||
- Arrays containing objects with macros
|
||||
- Objects containing arrays with macros
|
||||
- Mixed string interpolation and direct substitution at various nesting levels
|
||||
|
||||
## Checklist
|
||||
|
||||
### Configuration Schema Changes
|
||||
|
||||
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||
- [x] Add validation logic to ensure macro values are scalar types only
|
||||
|
||||
### Macro Substitution Logic
|
||||
|
||||
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||
- [x] Implement type-preserving direct substitution logic
|
||||
- [x] Implement string interpolation with type conversion
|
||||
- [x] Handle maps: recursively process all values
|
||||
- [x] Handle slices: recursively process all elements
|
||||
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||
|
||||
### API Response Changes
|
||||
|
||||
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||
- [x] Verify JSON marshaling preserves all types correctly
|
||||
|
||||
### Testing - Config Package
|
||||
|
||||
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||
|
||||
### Testing - Model Config Package
|
||||
|
||||
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||
- [x] Test metadata with various scalar types
|
||||
- [x] Test metadata with nested objects and arrays
|
||||
|
||||
### Testing - Proxy Manager
|
||||
|
||||
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||
- [x] Add test case for model with metadata
|
||||
- [x] Add test case for model without metadata
|
||||
- [x] Verify `llamaswap_meta` key presence/absence
|
||||
- [x] Verify type preservation in JSON output
|
||||
- [x] Verify macro substitution has occurred
|
||||
|
||||
### Documentation
|
||||
|
||||
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||
- [x] No additional documentation needed per project instructions
|
||||
|
||||
## Known Issues and Considerations
|
||||
|
||||
### Inconsistencies
|
||||
|
||||
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||
|
||||
### Design Decisions
|
||||
|
||||
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||
|
||||
- Avoids potential collisions with OpenAI API standard fields
|
||||
- Makes it clear this is llama-swap specific metadata
|
||||
- Easier for clients to distinguish standard vs. custom fields
|
||||
|
||||
2. **Why support nested structures?**
|
||||
|
||||
- Provides maximum flexibility for users
|
||||
- Aligns with the schemaless design principle
|
||||
- Example config already demonstrates this capability
|
||||
|
||||
3. **Why validate macro types?**
|
||||
- Prevents confusing behavior (e.g., substituting a map)
|
||||
- Makes configuration errors explicit at load time
|
||||
- Simpler implementation and testing
|
||||
@@ -0,0 +1,397 @@
|
||||
# Improve macro-in-macro support
|
||||
|
||||
**Status: COMPLETED ✅**
|
||||
|
||||
## Title
|
||||
|
||||
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||
|
||||
## Overview
|
||||
|
||||
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||
|
||||
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||
|
||||
**Outcomes:**
|
||||
- Macros can reference other macros defined earlier in the config
|
||||
- Macro substitution is deterministic and order-dependent
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||
- All existing tests pass
|
||||
- New tests validate substitution order and self-reference detection
|
||||
|
||||
## Design Requirements
|
||||
|
||||
### 1. YAML Parsing Strategy
|
||||
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||
|
||||
### 2. Data Structure Changes
|
||||
|
||||
#### Current Implementation (config.go:19)
|
||||
```go
|
||||
type MacroList map[string]any
|
||||
```
|
||||
|
||||
#### New Implementation
|
||||
```go
|
||||
type MacroList []MacroEntry
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
```
|
||||
|
||||
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||
|
||||
### 3. Macro Substitution Order Rules
|
||||
|
||||
The substitution must follow this hierarchy (from most specific to least):
|
||||
|
||||
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||
3. **Global macros** (first): Defined at config root level
|
||||
|
||||
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||
- The last macro defined is substituted first
|
||||
- This allows later macros to reference earlier ones
|
||||
- Single-pass substitution prevents circular dependencies
|
||||
|
||||
### 4. Macro Reference Rules
|
||||
|
||||
**Allowed:**
|
||||
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||
- Model macros can reference global macros
|
||||
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||
|
||||
**Prohibited:**
|
||||
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||
- Macro cannot reference macros defined **after** it
|
||||
- No circular references (prevented by single-pass, ordered substitution)
|
||||
|
||||
### 5. Validation Requirements
|
||||
|
||||
Add validation to detect:
|
||||
- **Self-references:** Macro value contains reference to its own name
|
||||
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||
|
||||
Error messages should be clear:
|
||||
```
|
||||
macro 'foo' contains self-reference
|
||||
unknown macro '${bar}' in model.cmd
|
||||
```
|
||||
|
||||
### 6. Implementation Changes
|
||||
|
||||
#### Files to Modify
|
||||
|
||||
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||
- Line 19: Change `MacroList` type definition
|
||||
- Line 69: Update `Macros MacroList` field
|
||||
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||
- Line 175-188: Update model-level macro validation
|
||||
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||
- Line 389-415: Update `validateMacro` to detect self-references
|
||||
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||
|
||||
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||
- Line 33: Update `Macros MacroList` field type
|
||||
|
||||
3. **All test files**
|
||||
- Update test fixtures to use ordered macro definitions
|
||||
- Ensure tests specify macro order explicitly
|
||||
|
||||
#### Core Algorithm
|
||||
|
||||
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||
|
||||
```go
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||
|
||||
// Add global macros first
|
||||
for _, entry := range config.Macros {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
|
||||
// Add model macros (can override global)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add reserved MODEL_ID macro at the end
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Check if PORT macro is needed
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro to the end (highest priority)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add this new helper function to replace `substituteMetadataMacros`:
|
||||
|
||||
```go
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 7. Self-Reference Detection
|
||||
|
||||
Add to `validateMacro` function:
|
||||
|
||||
```go
|
||||
func validateMacro(name string, value any) error {
|
||||
// ... existing validation ...
|
||||
|
||||
// Check for self-reference
|
||||
if str, ok := value.(string); ok {
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(str, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Plan
|
||||
|
||||
### 1. Migration Tests
|
||||
- **Test:** All existing macro tests still pass after YAML library migration
|
||||
- **Files:** All `*_test.go` files with macro tests
|
||||
|
||||
### 2. Macro Order Tests
|
||||
|
||||
#### Test: Macro-in-macro substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "echo ${B}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||
|
||||
#### Test: LIFO substitution order
|
||||
```yaml
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: "load ${full}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||
|
||||
#### Test: Model macro overrides global
|
||||
```yaml
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: "echo ${msg}"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||
|
||||
### 3. Reserved Macro Tests
|
||||
|
||||
#### Test: MODEL_ID substituted in macro
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: "${podman-llama} -m model.gguf"
|
||||
```
|
||||
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||
|
||||
### 4. Error Detection Tests
|
||||
|
||||
#### Test: Self-reference detection
|
||||
```yaml
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
```
|
||||
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||
|
||||
#### Test: Undefined macro reference
|
||||
```yaml
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
```
|
||||
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||
|
||||
### 5. Regression Tests
|
||||
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||
- Ensure all pass without modification (except test fixtures if needed)
|
||||
|
||||
## Checklist
|
||||
|
||||
### Phase 1: Data Structure Changes
|
||||
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||
- [ ] Implement helper functions:
|
||||
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||
|
||||
### Phase 2: Macro Validation Updates
|
||||
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||
- [ ] Test self-reference detection with new test case
|
||||
|
||||
### Phase 3: Macro Substitution Algorithm
|
||||
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||
- [ ] Substitute in metadata within same loop
|
||||
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||
|
||||
### Phase 4: Testing
|
||||
- [ ] Run `make test-dev` - fix any static checking errors
|
||||
- [ ] Add test: macro-in-macro basic substitution
|
||||
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||
- [ ] Add test: MODEL_ID in global macro used by model
|
||||
- [ ] Add test: PORT in global macro used by model
|
||||
- [ ] Add test: model macro overrides global macro in substitution
|
||||
- [ ] Add test: self-reference detection error
|
||||
- [ ] Add test: undefined macro reference error
|
||||
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||
|
||||
### Phase 5: Documentation
|
||||
- [ ] Update plan status in this file (mark completed)
|
||||
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||
- [ ] Verify no new error messages need user documentation
|
||||
|
||||
## Bug Example (Original Issue)
|
||||
|
||||
```yaml
|
||||
macros:
|
||||
"podman-llama": >
|
||||
podman run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||
|
||||
"standard-options": >
|
||||
--no-mmap --jinja
|
||||
|
||||
"kv8": >
|
||||
-fa on -ctk q8_0 -ctv q8_0
|
||||
```
|
||||
|
||||
**Current Bug:**
|
||||
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||
|
||||
**After Fix:**
|
||||
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||
// to make sure all the requests are accounted for.
|
||||
//
|
||||
// requests can be sent in parallel, and the tool will report the results.
|
||||
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// ----- CLI arguments ----------------------------------------------------
|
||||
var (
|
||||
baseurl string
|
||||
modelName string
|
||||
totalRequests int
|
||||
parallelization int
|
||||
)
|
||||
|
||||
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||
flag.Parse()
|
||||
|
||||
if baseurl == "" || modelName == "" {
|
||||
fmt.Println("Error: both -baseurl and -model are required.")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if totalRequests <= 0 {
|
||||
fmt.Println("Error: -requests must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
if parallelization <= 0 {
|
||||
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// ----- HTTP client -------------------------------------------------------
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// ----- Tracking response codes -------------------------------------------
|
||||
statusCounts := make(map[int]int) // map[statusCode]count
|
||||
var mu sync.Mutex // protects statusCounts
|
||||
|
||||
// ----- Request queue (buffered channel) ----------------------------------
|
||||
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||
|
||||
// Goroutine to fill the request queue
|
||||
go func() {
|
||||
for i := 0; i < totalRequests; i++ {
|
||||
requests <- i + 1
|
||||
}
|
||||
close(requests)
|
||||
}()
|
||||
|
||||
// ----- Worker pool -------------------------------------------------------
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < parallelization; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for reqID := range requests {
|
||||
// Build request payload as a single line JSON string
|
||||
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||
|
||||
// Send POST request
|
||||
req, err := http.NewRequest(http.MethodPost,
|
||||
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||
bytes.NewReader([]byte(payload)))
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||
mu.Lock()
|
||||
statusCounts[-1]++
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
// Record status code
|
||||
mu.Lock()
|
||||
statusCounts[resp.StatusCode]++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
// ----- Status ticker (prints every second) -------------------------------
|
||||
done := make(chan struct{})
|
||||
tickerDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
startTime := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mu.Lock()
|
||||
// Compute how many requests have completed so far
|
||||
completed := 0
|
||||
for _, cnt := range statusCounts {
|
||||
completed += cnt
|
||||
}
|
||||
// Calculate duration and progress
|
||||
duration := time.Since(startTime)
|
||||
progress := completed * 100 / totalRequests
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||
mu.Unlock()
|
||||
case <-done:
|
||||
duration := time.Since(startTime)
|
||||
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||
close(tickerDone)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all workers to finish
|
||||
wg.Wait()
|
||||
close(done) // stops the status-update goroutine
|
||||
<-tickerDone // give ticker time to finish / print
|
||||
|
||||
// ----- Summary ------------------------------------------------------------
|
||||
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||
mu.Lock()
|
||||
for code, cnt := range statusCounts {
|
||||
if code == -1 {
|
||||
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||
} else {
|
||||
fmt.Printf("%d : %d\n", code, cnt)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
**
|
||||
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||
|
||||
- process is killed externally, what happens with cmd.Wait() *
|
||||
✔︎ it returns. catches crashes.*
|
||||
|
||||
- process ignores SIGTERM*
|
||||
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||
|
||||
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||
x they stick around. have to be manually killed.*
|
||||
|
||||
- .WithTimeout()'s cancel is called *
|
||||
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||
|
||||
- parent receives SIGINT/SIGTERM, what happens
|
||||
✔︎ waits for child process to exit, then exits gracefully.
|
||||
*/
|
||||
func main() {
|
||||
|
||||
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||
cmd := exec.CommandContext(ctx,
|
||||
"../../build/simple-responder_darwin_arm64",
|
||||
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// set a wait delay before signing sig kill
|
||||
cmd.WaitDelay = 500 * time.Millisecond
|
||||
cmd.Cancel = func() error {
|
||||
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
//return nil
|
||||
|
||||
// this error is returned by cmd.Wait(), and can be used to
|
||||
// single an error when the process couldn't be normally terminated
|
||||
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||
// as WaitDelay timing out will override the any error set here.
|
||||
//
|
||||
// test by enabling/disabling -ignore-sig-term on the process
|
||||
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||
// without it, it will show the "new error from cancel"
|
||||
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Error starting process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||
// this program to eventually exit gracefully.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
signal := <-sigChan
|
||||
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||
} else {
|
||||
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||
}
|
||||
fmt.Println("✔︎ Child process exited, Done.")
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
|
||||
// Check if streaming is requested
|
||||
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||
isStreaming := c.Query("stream") == "true"
|
||||
|
||||
if isStreaming {
|
||||
// Set headers for streaming
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
// Send 10 "asdf" tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
data := gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []gin.H{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": gin.H{
|
||||
"content": "asdf",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", data)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Send final data with usage info
|
||||
finalData := gin.H{
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
// add timings to simulate llama.cpp
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
}
|
||||
c.SSEvent("message", finalData)
|
||||
c.Writer.Flush()
|
||||
|
||||
// Send [DONE]
|
||||
c.SSEvent("message", "[DONE]")
|
||||
c.Writer.Flush()
|
||||
} else {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||
"request_body": string(bodyBytes),
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
"timings": gin.H{
|
||||
"prompt_n": 25,
|
||||
"prompt_ms": 13,
|
||||
"predicted_n": 10,
|
||||
"predicted_ms": 17,
|
||||
"predicted_per_second": 10,
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// for issue #62 to check model name strips profile slug
|
||||
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
defer c.Request.Body.Close()
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if modelName != *expectedModel {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
// llama-server compatibility: /completion
|
||||
r.POST("/completion", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"responseMessage": *responseMessage,
|
||||
"usage": gin.H{
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 35,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// issue #41
|
||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||
// Parse the multipart form
|
||||
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the model from the form values
|
||||
model := c.Request.FormValue("model")
|
||||
|
||||
if model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the file from the form
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file content to get its size
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
fileSize := len(fileBytes)
|
||||
|
||||
// Return a JSON response with the model and transcription text including file size
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||
"model": model,
|
||||
|
||||
// expose some header values for testing
|
||||
"h_content_type": c.GetHeader("Content-Type"),
|
||||
"h_content_length": c.GetHeader("Content-Length"),
|
||||
})
|
||||
})
|
||||
|
||||
r.GET("/v1/audio/voices", func(c *gin.Context) {
|
||||
model := c.Query("model")
|
||||
c.JSON(http.StatusOK, gin.H{"voices": []string{"voice1"}, "model": model})
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
if !*silent {
|
||||
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
countSigInt := 0
|
||||
|
||||
runloop:
|
||||
for {
|
||||
signal := <-sigChan
|
||||
switch signal {
|
||||
case syscall.SIGINT:
|
||||
countSigInt++
|
||||
if countSigInt > 1 {
|
||||
break runloop
|
||||
} else {
|
||||
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||
}
|
||||
case syscall.SIGTERM:
|
||||
if *ignoreSigTerm {
|
||||
log.Println("Ignoring SIGTERM")
|
||||
} else {
|
||||
log.Println("Received SIGTERM, shutting down")
|
||||
break runloop
|
||||
}
|
||||
default:
|
||||
break runloop
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
# wol-proxy
|
||||
|
||||
wol-proxy automatically wakes up a suspended llama-swap server using Wake-on-LAN when requests are received.
|
||||
|
||||
When a request arrives and llama-swap is unavailable, wol-proxy sends a WOL packet and holds the request until the server becomes available. If the server doesn't respond within the timeout period (default: 60 seconds), the request is dropped.
|
||||
|
||||
This utility helps conserve energy by allowing GPU-heavy servers to remain suspended when idle, as they can consume hundreds of watts even when not actively processing requests.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
# minimal
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080
|
||||
|
||||
# everything
|
||||
$ ./wol-proxy -mac BA:DC:0F:FE:E0:00 -upstream http://192.168.1.13:8080 \
|
||||
# use debug log level
|
||||
-log debug \
|
||||
# altenerative listening port
|
||||
-listen localhost:9999 \
|
||||
# seconds to hold requests waiting for upstream to be ready
|
||||
-timeout 30
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
`GET /status` - that's it. Everything else is proxied to the upstream server.
|
||||
@@ -0,0 +1,64 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Loading...</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.loader {
|
||||
text-align: center;
|
||||
}
|
||||
.stats {
|
||||
font-size: 18px;
|
||||
color: #333;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.stats-label {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="loader">
|
||||
<p>Waking up upstream server...</p>
|
||||
<div class="stats">
|
||||
<div><span class="stats-label">Time elapsed:</span> <span id="elapsed">0s</span></div>
|
||||
<div><span id="attempts"> </span></div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
var startTime = Date.now();
|
||||
var attempts = 0;
|
||||
|
||||
setInterval(function() {
|
||||
var elapsed = (Date.now() - startTime) / 1000;
|
||||
document.getElementById('elapsed').textContent = elapsed.toFixed(1) + 's';
|
||||
}, 100);
|
||||
|
||||
// Check status every second
|
||||
setInterval(function() {
|
||||
attempts++;
|
||||
var dots = '.'.repeat((attempts % 10) || 10);
|
||||
document.getElementById('attempts').textContent = dots;
|
||||
|
||||
fetch('/status')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(t) {
|
||||
if (t.indexOf('status: ready') !== -1) {
|
||||
location.reload();
|
||||
}
|
||||
})
|
||||
.catch(function() {});
|
||||
}, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,333 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed index.html
|
||||
var loadingPageHTML string
|
||||
|
||||
var (
|
||||
flagMac = flag.String("mac", "", "mac address to send WoL packet to")
|
||||
flagUpstream = flag.String("upstream", "", "upstream proxy address to send requests to")
|
||||
flagListen = flag.String("listen", ":8080", "listen address to listen on")
|
||||
flagLog = flag.String("log", "info", "log level (debug, info, warn, error)")
|
||||
flagTimeout = flag.Int("timeout", 60, "seconds requests wait for upstream response before failing")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
switch *flagLog {
|
||||
case "debug":
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
case "info":
|
||||
slog.SetLogLoggerLevel(slog.LevelInfo)
|
||||
case "warn":
|
||||
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||
case "error":
|
||||
slog.SetLogLoggerLevel(slog.LevelError)
|
||||
default:
|
||||
slog.Error("invalid log level", "logLevel", *flagLog)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate flags
|
||||
if *flagListen == "" {
|
||||
slog.Error("listen address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagMac == "" {
|
||||
slog.Error("mac address is required")
|
||||
return
|
||||
}
|
||||
|
||||
if *flagTimeout < 1 {
|
||||
slog.Error("timeout must be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
var upstreamURL *url.URL
|
||||
var err error
|
||||
// validate mac address
|
||||
if _, err = net.ParseMAC(*flagMac); err != nil {
|
||||
slog.Error("invalid mac address", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *flagUpstream == "" {
|
||||
slog.Error("upstream proxy address is required")
|
||||
return
|
||||
} else {
|
||||
upstreamURL, err = url.ParseRequestURI(*flagUpstream)
|
||||
if err != nil {
|
||||
slog.Error("error parsing upstream url", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxy := newProxy(upstreamURL)
|
||||
server := &http.Server{
|
||||
Addr: *flagListen,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// start the server
|
||||
go func() {
|
||||
slog.Info("server starting on", "address", *flagListen)
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
slog.Error("error starting server", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// graceful shutdown
|
||||
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
<-ctx.Done()
|
||||
server.Close()
|
||||
}
|
||||
|
||||
type upstreamStatus string
|
||||
|
||||
const (
|
||||
notready upstreamStatus = "not ready"
|
||||
ready upstreamStatus = "ready"
|
||||
)
|
||||
|
||||
type proxyServer struct {
|
||||
upstreamProxy *httputil.ReverseProxy
|
||||
failCount int
|
||||
statusMutex sync.RWMutex
|
||||
status upstreamStatus
|
||||
}
|
||||
|
||||
func newProxy(url *url.URL) *proxyServer {
|
||||
p := httputil.NewSingleHostReverseProxy(url)
|
||||
proxy := &proxyServer{
|
||||
upstreamProxy: p,
|
||||
status: notready,
|
||||
failCount: 0,
|
||||
}
|
||||
|
||||
// start a goroutine to monitor upstream status via SSE
|
||||
go func() {
|
||||
eventsUrl := url.Scheme + "://" + url.Host + "/api/events"
|
||||
client := &http.Client{
|
||||
Timeout: 0, // No timeout for SSE connection
|
||||
}
|
||||
|
||||
waitDuration := 10 * time.Second
|
||||
|
||||
for {
|
||||
slog.Debug("connecting to SSE endpoint", "url", eventsUrl)
|
||||
|
||||
req, err := http.NewRequest("GET", eventsUrl, nil)
|
||||
if err != nil {
|
||||
slog.Warn("failed to create SSE request", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(waitDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Error("failed to connect to SSE endpoint", "error", err)
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Warn("SSE endpoint returned non-OK status", "status", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully connected to SSE endpoint
|
||||
slog.Info("connected to SSE endpoint, upstream ready")
|
||||
proxy.setStatus(ready)
|
||||
proxy.resetFailures()
|
||||
|
||||
// Read from the SSE stream to detect disconnection
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
// use a fairly large buffer to avoid scanner errors when reading large SSE events
|
||||
buf := make([]byte, 0, 1024*1024*2)
|
||||
scanner.Buffer(buf, 1024*1024*2)
|
||||
events := 0
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
fmt.Print("Events: ")
|
||||
}
|
||||
for scanner.Scan() {
|
||||
if slog.Default().Enabled(context.Background(), slog.LevelDebug) {
|
||||
// Just read the events to keep connection alive
|
||||
// We don't need to process the event data
|
||||
events++
|
||||
fmt.Printf("%d, ", events)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if err := scanner.Err(); err != nil {
|
||||
slog.Error("error reading from SSE stream", "error", err)
|
||||
}
|
||||
|
||||
// Connection closed or error occurred
|
||||
_ = resp.Body.Close()
|
||||
slog.Info("SSE connection closed, upstream not ready")
|
||||
proxy.setStatus(notready)
|
||||
proxy.incFail(1)
|
||||
|
||||
// Wait before reconnecting
|
||||
time.Sleep(waitDuration)
|
||||
}
|
||||
}()
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
func (p *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && r.URL.Path == "/status" {
|
||||
status := string(p.getStatus())
|
||||
failCount := p.getFailures()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(200)
|
||||
fmt.Fprintf(w, "status: %s\n", status)
|
||||
fmt.Fprintf(w, "failures: %d\n", failCount)
|
||||
return
|
||||
}
|
||||
|
||||
if p.getStatus() == notready {
|
||||
path := r.URL.Path
|
||||
if strings.HasPrefix(path, "/api/events") {
|
||||
slog.Debug("Skipping wake up", "req", path)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("upstream not ready, sending magic packet", "req", path, "from", r.RemoteAddr)
|
||||
if err := sendMagicPacket(*flagMac); err != nil {
|
||||
slog.Warn("failed to send magic WoL packet", "error", err)
|
||||
}
|
||||
|
||||
// For root or UI path requests, return loading page with status polling
|
||||
// the web page will do the polling and redirect when ready
|
||||
if path == "/" || strings.HasPrefix(path, "/ui/") {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, loadingPageHTML)
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
timeout, cancel := context.WithTimeout(context.Background(), time.Duration(*flagTimeout)*time.Second)
|
||||
defer cancel()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
slog.Info("timeout waiting for upstream to be ready")
|
||||
http.Error(w, "timeout", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.getStatus() == ready {
|
||||
ticker.Stop()
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.upstreamProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *proxyServer) getStatus() upstreamStatus {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.status
|
||||
}
|
||||
|
||||
func (p *proxyServer) setStatus(status upstreamStatus) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.status = status
|
||||
}
|
||||
|
||||
func (p *proxyServer) incFail(num int) {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount += num
|
||||
}
|
||||
|
||||
func (p *proxyServer) getFailures() int {
|
||||
p.statusMutex.RLock()
|
||||
defer p.statusMutex.RUnlock()
|
||||
return p.failCount
|
||||
}
|
||||
|
||||
func (p *proxyServer) resetFailures() {
|
||||
p.statusMutex.Lock()
|
||||
defer p.statusMutex.Unlock()
|
||||
p.failCount = 0
|
||||
}
|
||||
|
||||
func sendMagicPacket(macAddr string) error {
|
||||
hwAddr, err := net.ParseMAC(macAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hwAddr) != 6 {
|
||||
return errors.New("invalid MAC address")
|
||||
}
|
||||
|
||||
// Create the magic packet.
|
||||
packet := make([]byte, 102)
|
||||
// Add 6 bytes of 0xFF.
|
||||
for i := 0; i < 6; i++ {
|
||||
packet[i] = 0xFF
|
||||
}
|
||||
// Repeat the MAC address 16 times.
|
||||
for i := 1; i <= 16; i++ {
|
||||
copy(packet[i*6:], hwAddr)
|
||||
}
|
||||
|
||||
// Send the packet using UDP.
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: 9,
|
||||
}
|
||||
conn, err := net.DialUDP("udp", nil, &addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
{
|
||||
"$schema": "https://json-schema.org/draft-07/schema#",
|
||||
"$id": "llama-swap-config-schema.json",
|
||||
"title": "llama-swap configuration",
|
||||
"description": "Configuration file for llama-swap",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"models"
|
||||
],
|
||||
"definitions": {
|
||||
"macros": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 0,
|
||||
"maxLength": 1024
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
}
|
||||
]
|
||||
},
|
||||
"propertyNames": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 64,
|
||||
"pattern": "^[a-zA-Z0-9_-]+$",
|
||||
"not": {
|
||||
"enum": [
|
||||
"PORT",
|
||||
"MODEL_ID"
|
||||
]
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of string substitutions. Macros are reusable snippets used in model cmd, cmdStop, proxy, checkEndpoint, filters.stripParams. Macro names must be <64 chars, match ^[a-zA-Z0-9_-]+$, and not be PORT or MODEL_ID. Values can be string, number, or boolean. Macros can reference other macros defined before them."
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"healthCheckTimeout": {
|
||||
"type": "integer",
|
||||
"minimum": 15,
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"debug",
|
||||
"info",
|
||||
"warn",
|
||||
"error"
|
||||
],
|
||||
"default": "info",
|
||||
"description": "Sets the logging value. Valid values: debug, info, warn, error."
|
||||
},
|
||||
"logTimeFormat": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"",
|
||||
"ansic",
|
||||
"unixdate",
|
||||
"rubydate",
|
||||
"rfc822",
|
||||
"rfc822z",
|
||||
"rfc850",
|
||||
"rfc1123",
|
||||
"rfc1123z",
|
||||
"rfc3339",
|
||||
"rfc3339nano",
|
||||
"kitchen",
|
||||
"stamp",
|
||||
"stampmilli",
|
||||
"stampmicro",
|
||||
"stampnano"
|
||||
],
|
||||
"default": "",
|
||||
"description": "Enables and sets the logging timestamp format. Valid values: \"\", \"ansic\", \"unixdate\", \"rubydate\", \"rfc822\", \"rfc822z\", \"rfc850\", \"rfc1123\", \"rfc1123z\", \"rfc3339\", \"rfc3339nano\", \"kitchen\", \"stamp\", \"stampmilli\", \"stampmicro\", and \"stampnano\". For more info, read: https://pkg.go.dev/time#pkg-constants"
|
||||
},
|
||||
"metricsMaxInMemory": {
|
||||
"type": "integer",
|
||||
"default": 1000,
|
||||
"description": "Maximum number of metrics to keep in memory. Controls how many metrics are stored before older ones are discarded."
|
||||
},
|
||||
"captureBuffer": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 5,
|
||||
"description": "Size in megabytes of the buffer for storing request/response captures. Set to 0 to disable captures."
|
||||
},
|
||||
"startPort": {
|
||||
"type": "integer",
|
||||
"default": 5800,
|
||||
"description": "Starting port number for the automatic ${PORT} macro. The ${PORT} macro is incremented for every model that uses it."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Inject loading status updates into the reasoning field. When true, a stream of loading messages will be sent to the client."
|
||||
},
|
||||
"includeAliasesInList": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Present aliases within the /v1/models OpenAI API listing. when true, model aliases will be output to the API model listing duplicating all fields except for Id so chat UIs can use the alias equivalent to the original."
|
||||
},
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"models": {
|
||||
"type": "object",
|
||||
"description": "A dictionary of model configurations. Each key is a model's ID. Model settings have defaults if not defined. The model's ID is available as ${MODEL_ID}.",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"cmd"
|
||||
],
|
||||
"properties": {
|
||||
"macros": {
|
||||
"$ref": "#/definitions/macros"
|
||||
},
|
||||
"cmd": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "Command to run to start the inference server. Macros can be used. Comments allowed with |."
|
||||
},
|
||||
"cmdStop": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Command to run to stop the model gracefully. Uses ${PID} macro for upstream process id. If empty, default shutdown behavior is used."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 128,
|
||||
"description": "Display name for the model. Used in v1/models API response."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"maxLength": 1024,
|
||||
"description": "Description for the model. Used in v1/models API response."
|
||||
},
|
||||
"env": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-Z_][A-Z0-9_]*=.*$"
|
||||
},
|
||||
"default": [],
|
||||
"description": "Array of environment variables to inject into cmd's environment. Each value is a string in ENV_NAME=value format."
|
||||
},
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"default": "http://localhost:${PORT}",
|
||||
"format": "uri",
|
||||
"description": "URL where llama-swap routes API requests. If custom port is used in cmd, this must be set."
|
||||
},
|
||||
"aliases": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Alternative model names for this configuration. Must be unique globally."
|
||||
},
|
||||
"checkEndpoint": {
|
||||
"type": "string",
|
||||
"default": "/health",
|
||||
"pattern": "^/.*$|^none$",
|
||||
"description": "URL path to check if the server is ready. Use 'none' to skip health checking."
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Override the model name sent to upstream server. Useful if upstream expects a different name."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
},
|
||||
"setParamsByID": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"default": {},
|
||||
"description": "Dictionary mapping requested model IDs (or aliases) to parameters to set/override in requests. Applied after setParams and can override those values. Useful with aliases to vary behaviour depending on which alias the client used (e.g. different reasoning_effort per alias). Keys support ${MODEL_ID} macro substitution. Protected params like 'model' cannot be overridden."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Supports stripParams, setParams, and setParamsByID."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of arbitrary values included in /v1/models. Can contain complex types. Only passed through in /v1/models responses."
|
||||
},
|
||||
"concurrencyLimit": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Overrides allowed number of active parallel requests to a model. 0 uses internal default of 10. >0 overrides default. Requests exceeding limit get HTTP 429."
|
||||
},
|
||||
"sendLoadingState": {
|
||||
"type": "boolean",
|
||||
"description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting."
|
||||
},
|
||||
"unlisted": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"groups": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"members"
|
||||
],
|
||||
"properties": {
|
||||
"swap": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls model swapping behaviour within the group. True: only one model runs at a time. False: all models can run together."
|
||||
},
|
||||
"exclusive": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Controls how the group affects other groups. True: causes all other groups to unload when this group runs a model. False: does not affect other groups."
|
||||
},
|
||||
"persistent": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Prevents other groups from unloading the models in this group. Does not affect individual model behaviour."
|
||||
},
|
||||
"members": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Array of model IDs that are members of this group. Model IDs must be defined in models."
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "A dictionary of group settings. Provides advanced controls over model swapping behaviour. Model IDs must be defined in models. A model can only be a member of one group. Behaviour controlled via swap, exclusive, persistent."
|
||||
},
|
||||
"hooks": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"on_startup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"preload": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [],
|
||||
"description": "List of model IDs to load on startup. Model names must match keys in models. When preloading multiple models, define a group to prevent swapping."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Actions to perform on startup. Only supported action is preload."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
+429
-72
@@ -1,84 +1,441 @@
|
||||
# Seconds to wait for llama.cpp to be available to serve requests
|
||||
# Default (and minimum): 15 seconds
|
||||
healthCheckTimeout: 15
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
||||
logRequests: true
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# captureBuffer: how many MBs to allocate for storing request/response captures
|
||||
# - optional, default: 10
|
||||
# - set to 0 to disable
|
||||
captureBuffer: 15
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
"llama":
|
||||
cmd: >
|
||||
models/llama-server-osx
|
||||
--port 9001
|
||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
proxy: http://127.0.0.1:9001
|
||||
# keys are the model names used in API requests
|
||||
"gpt-oss-120b":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# list of model name aliases this llama.cpp instance can serve
|
||||
aliases:
|
||||
- gpt-4o-mini
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/gpt-oss-120B.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# check this path for a HTTP 200 response for the server to be ready
|
||||
checkEndpoint: /health
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "gpt-oss 120B"
|
||||
|
||||
# unload model after 5 seconds
|
||||
ttl: 5
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A thinking model from OpenAI"
|
||||
|
||||
"qwen":
|
||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9002
|
||||
aliases:
|
||||
- gpt-3.5-turbo
|
||||
|
||||
# Embedding example with Nomic
|
||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
||||
"nomic":
|
||||
proxy: http://127.0.0.1:9005
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9005
|
||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
||||
--ctx-size 8192
|
||||
--batch-size 8192
|
||||
--rope-scaling yarn
|
||||
--rope-freq-scale 0.75
|
||||
-ngl 99
|
||||
--embeddings
|
||||
|
||||
# Reranking example with bge-reranker
|
||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
||||
"bge-reranker":
|
||||
proxy: http://127.0.0.1:9006
|
||||
cmd: >
|
||||
models/llama-server-osx --port 9006
|
||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
||||
--ctx-size 8192
|
||||
--reranking
|
||||
|
||||
|
||||
"simple":
|
||||
# example of setting environment variables
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- CUDA_VISIBLE_DEVICES=0,1
|
||||
- env1=hello
|
||||
cmd: build/simple-responder --port 8999
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
|
||||
# use "none" to skip check. Caution this may cause some requests to fail
|
||||
# until the upstream server is ready for traffic
|
||||
checkEndpoint: none
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# don't use these, just for testing if things are broken
|
||||
"broken":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
||||
proxy: http://127.0.0.1:8999
|
||||
unlisted: true
|
||||
"broken_timeout":
|
||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||
proxy: http://127.0.0.1:9000
|
||||
unlisted: true
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# creating a coding profile with models for code generation and general questions
|
||||
profiles:
|
||||
coding:
|
||||
- "qwen"
|
||||
- "llama"
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "openai/gpt-oss-120B"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - always runs for the model
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# setParamsByID: a dictionary of parameters to set based the model ID
|
||||
# - optional, default: empty dictionary
|
||||
# - combine with aliases to create variant behaviour without reloading the model
|
||||
# - parameters are set in the request body JSON
|
||||
# - run after setParams so it will override any settings
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
# - model aliases will be automatically created for each key
|
||||
setParamsByID:
|
||||
"${MODEL_ID}":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: medium
|
||||
"${MODEL_ID}:high":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: low
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# groups: a dictionary of group settings
|
||||
# - optional, default: empty dictionary
|
||||
# - provides advanced controls over model swapping behaviour
|
||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||
# - model IDs must be defined in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap: controls the model swapping behaviour in within the group
|
||||
# - optional, default: true
|
||||
# - true : only one model is allowed to run at a time
|
||||
# - false: all models can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: controls how the group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: causes all other groups to unload when this group runs a model
|
||||
# - false: does not affect other groups
|
||||
exclusive: true
|
||||
|
||||
# members references the models defined above
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# Example:
|
||||
# - in group2 all models can run at the same time
|
||||
# - when a different group is loaded it causes all running models in this group to unload
|
||||
"group2":
|
||||
swap: false
|
||||
|
||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||
# - the models in group2 will be loaded but will not unload any other groups
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# Example:
|
||||
# - a persistent group, prevents other groups from unloading it
|
||||
"forever":
|
||||
# persistent: prevents over groups from unloading the models in this group
|
||||
# - optional, default: false
|
||||
# - does not affect individual model behaviour
|
||||
persistent: true
|
||||
|
||||
# set swap/exclusive to false to prevent swapping inside the group
|
||||
# and the unloading of other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
Executable
+164
@@ -0,0 +1,164 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
LOG_DEBUG=${LOG_DEBUG:-0}
|
||||
DEBUG_ABORT_BUILD=${DEBUG_ABORT_BUILD:-}
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set llama.cpp base image, customizable using the BASE_LLAMACPP_IMAGE environment
|
||||
# variable, this permits testing with forked llama.cpp repositories
|
||||
BASE_IMAGE=${BASE_LLAMACPP_IMAGE:-ghcr.io/ggml-org/llama.cpp}
|
||||
SD_IMAGE=${BASE_SDCPP_IMAGE:-ghcr.io/leejet/stable-diffusion.cpp}
|
||||
|
||||
# Set llama-swap repository, automatically uses GITHUB_REPOSITORY variable
|
||||
# to enable easy container builds on forked repos
|
||||
LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
|
||||
# the most recent llama-swap tag
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
SD_TAG=master-${ARCH}
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
CONTAINER_TAG="ghcr.io/${LS_REPO}:v${LS_VER}-${ARCH}-${LCPP_TAG}"
|
||||
CONTAINER_LATEST="ghcr.io/${LS_REPO}:${ARCH}"
|
||||
USER_UID=0
|
||||
USER_GID=0
|
||||
USER_HOME=/root
|
||||
|
||||
if [ "$CONTAINER_TYPE" == "non-root" ]; then
|
||||
CONTAINER_TAG="${CONTAINER_TAG}-non-root"
|
||||
CONTAINER_LATEST="${CONTAINER_LATEST}-non-root"
|
||||
USER_UID=10001
|
||||
USER_GID=10001
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build --provenance=false -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
# For architectures with stable-diffusion.cpp support, layer sd-server on top
|
||||
case "$ARCH" in
|
||||
"musa" | "vulkan")
|
||||
log_info "Adding sd-server to $CONTAINER_TAG"
|
||||
docker build --provenance=false -f llama-swap-sd.Containerfile \
|
||||
--build-arg BASE=${CONTAINER_TAG} \
|
||||
--build-arg SD_IMAGE=${SD_IMAGE} --build-arg SD_TAG=${SD_TAG} \
|
||||
--build-arg UID=${USER_UID} --build-arg GID=${USER_GID} \
|
||||
-t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . ;;
|
||||
esac
|
||||
|
||||
if [ "$PUSH_IMAGES" == "true" ]; then
|
||||
docker push ${CONTAINER_TAG}
|
||||
docker push ${CONTAINER_LATEST}
|
||||
fi
|
||||
done
|
||||
@@ -0,0 +1,33 @@
|
||||
healthCheckTimeout: 300
|
||||
logRequests: true
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
"smollm2":
|
||||
proxy: "http://127.0.0.1:9999"
|
||||
cmd: >
|
||||
/app/llama-server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
--port 9999
|
||||
|
||||
z-image:
|
||||
checkEndpoint: /
|
||||
cmd: |
|
||||
/app/sd-server
|
||||
--listen-port 9999
|
||||
--diffusion-fa
|
||||
--diffusion-model /models/z_image_turbo-Q8_0.gguf
|
||||
--vae /models/ae.safetensors
|
||||
--llm /models/qwen3-4b-instruct-2507-q8_0.gguf
|
||||
--offload-to-cpu
|
||||
--cfg-scale 1.0
|
||||
--height 512 --width 512
|
||||
--steps 8
|
||||
aliases: [gpt-image-1,dall-e-2,dall-e-3,gpt-image-1-mini,gpt-image-1.5]
|
||||
@@ -0,0 +1,11 @@
|
||||
ARG SD_IMAGE=ghcr.io/leejet/stable-diffusion.cpp
|
||||
ARG SD_TAG=master-vulkan
|
||||
ARG BASE=llama-swap:latest
|
||||
|
||||
FROM ${SD_IMAGE}:${SD_TAG} AS sd-source
|
||||
FROM ${BASE}
|
||||
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
|
||||
COPY --from=sd-source --chown=${UID}:${GID} /sd-server /app/sd-server
|
||||
@@ -0,0 +1,44 @@
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# has to be after the FROM
|
||||
ARG LS_VER=170
|
||||
ARG LS_REPO=mostlygeek/llama-swap
|
||||
|
||||
# Set default UID/GID arguments
|
||||
ARG UID=10001
|
||||
ARG GID=10001
|
||||
ARG USER_HOME=/app
|
||||
|
||||
# Add user/group
|
||||
ENV HOME=$USER_HOME
|
||||
RUN if [ $UID -ne 0 ]; then \
|
||||
if [ $GID -ne 0 ]; then \
|
||||
groupadd --system --gid $GID app; \
|
||||
fi; \
|
||||
useradd --system --uid $UID --gid $GID \
|
||||
--home $USER_HOME app; \
|
||||
fi
|
||||
|
||||
# Handle paths
|
||||
RUN mkdir --parents $HOME /app
|
||||
RUN chown --recursive $UID:$GID $HOME /app
|
||||
|
||||
# Switch user
|
||||
USER $UID:$GID
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Add /app to PATH
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
RUN \
|
||||
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
|
||||
rm "llama-swap_${LS_VER}_linux_amd64.tar.gz"
|
||||
|
||||
COPY --chown=$UID:$GID config.example.yaml /app/config.yaml
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
|
Before Width: | Height: | Size: 261 KiB After Width: | Height: | Size: 261 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 351 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 198 KiB |
@@ -0,0 +1,467 @@
|
||||
# config.yaml
|
||||
|
||||
llama-swap is designed to be very simple: one binary, one configuration file.
|
||||
|
||||
## minimal viable config
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} --model /path/to/model.gguf
|
||||
```
|
||||
|
||||
This is enough to launch `llama-server` to serve `model1`. Of course, llama-swap is about making it possible to serve many models:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/model.gguf
|
||||
model2:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/another_model.gguf
|
||||
model3:
|
||||
cmd: llama-server --port ${PORT} -m /path/to/third_model.gguf
|
||||
```
|
||||
|
||||
With this configuration models will be hot swapped and loaded on demand. The special `${PORT}` macro provides a unique port per model. Useful if you want to run multiple models at the same time with the `groups` feature.
|
||||
|
||||
## Advanced control with `cmd`
|
||||
|
||||
llama-swap is also about customizability. You can use any CLI flag available:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
model1:
|
||||
cmd: | # support for multi-line
|
||||
llama-server --PORT ${PORT} -m /path/to/model.gguf
|
||||
--ctx-size 8192
|
||||
--jinja
|
||||
--cache-type-k q8_0
|
||||
--cache-type-v q8_0
|
||||
```
|
||||
|
||||
## Support for any OpenAI API compatible server
|
||||
|
||||
llama-swap supports any OpenAI API compatible server. If you can run it on the CLI llama-swap will be able to manage it. Even if it's run in Docker or Podman containers.
|
||||
|
||||
```yaml
|
||||
models:
|
||||
"Q3-30B-CODER-VLLM":
|
||||
name: "Qwen3 30B Coder vllm AWQ (Q3-30B-CODER-VLLM)"
|
||||
# cmdStop provides a reliable way to stop containers
|
||||
cmdStop: docker stop vllm-coder
|
||||
cmd: |
|
||||
docker run --init --rm --name vllm-coder
|
||||
--runtime=nvidia --gpus '"device=2,3"'
|
||||
--shm-size=16g
|
||||
-v /mnt/nvme/vllm-cache:/root/.cache
|
||||
-v /mnt/ssd-extra/models:/models -p ${PORT}:8000
|
||||
vllm/vllm-openai:v0.10.0
|
||||
--model "/models/cpatonn/Qwen3-Coder-30B-A3B-Instruct-AWQ"
|
||||
--served-model-name "Q3-30B-CODER-VLLM"
|
||||
--enable-expert-parallel
|
||||
--swap-space 16
|
||||
--max-num-seqs 512
|
||||
--max-model-len 65536
|
||||
--max-seq-len-to-capture 65536
|
||||
--gpu-memory-utilization 0.9
|
||||
--tensor-parallel-size 2
|
||||
--trust-remote-code
|
||||
```
|
||||
|
||||
## Many more features..
|
||||
|
||||
llama-swap supports many more features to customize how you want to manage your environment.
|
||||
|
||||
| Feature | Description |
|
||||
| --------- | ---------------------------------------------- |
|
||||
| `ttl` | automatic unloading of models after a timeout |
|
||||
| `macros` | reusable snippets to use in configurations |
|
||||
| `groups` | run multiple models at a time |
|
||||
| `hooks` | event driven functionality |
|
||||
| `env` | define environment variables per model |
|
||||
| `aliases` | serve a model with different names |
|
||||
| `filters` | modify requests before sending to the upstream |
|
||||
| `...` | And many more tweaks |
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
|
||||
#
|
||||
# llama-swap YAML configuration example
|
||||
# -------------------------------------
|
||||
#
|
||||
# 💡 Tip - Use an LLM with this file!
|
||||
# ====================================
|
||||
# This example configuration is written to be LLM friendly. Try
|
||||
# copying this file into an LLM and asking it to explain or generate
|
||||
# sections for you.
|
||||
# ====================================
|
||||
|
||||
# Usage notes:
|
||||
# - Below are all the available configuration options for llama-swap.
|
||||
# - Settings noted as "required" must be in your configuration file
|
||||
# - Settings noted as "optional" can be omitted
|
||||
|
||||
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||
# - optional, default: 120
|
||||
# - minimum value is 15 seconds, anything less will be set to this value
|
||||
healthCheckTimeout: 500
|
||||
|
||||
# logLevel: sets the logging value
|
||||
# - optional, default: info
|
||||
# - Valid log levels: debug, info, warn, error
|
||||
logLevel: info
|
||||
|
||||
# logTimeFormat: enables and sets the logging timestamp format
|
||||
# - optional, default (disabled): ""
|
||||
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
|
||||
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
|
||||
# "stamp", "stampmilli", "stampmicro", and "stampnano".
|
||||
# - For more info, read: https://pkg.go.dev/time#pkg-constants
|
||||
logTimeFormat: ""
|
||||
|
||||
# logToStdout: controls what is logged to stdout
|
||||
# - optional, default: "proxy"
|
||||
# - valid values:
|
||||
# - "proxy": logs generated by llama-swap when swapping models,
|
||||
# handling requests, etc.
|
||||
# - "upstream": a copy of an upstream processes stdout logs
|
||||
# - "both": both the proxy and upstream logs interleaved together
|
||||
# - "none": no logs are ever written to stdout
|
||||
logToStdout: "proxy"
|
||||
|
||||
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||
# - optional, default: 1000
|
||||
# - controls how many metrics are stored in memory before older ones are discarded
|
||||
# - useful for limiting memory usage when processing large volumes of metrics
|
||||
metricsMaxInMemory: 1000
|
||||
|
||||
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||
# - optional, default: 5800
|
||||
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||
# - it is automatically incremented for every model that uses it
|
||||
startPort: 10001
|
||||
|
||||
# sendLoadingState: inject loading status updates into the reasoning (thinking)
|
||||
# field
|
||||
# - optional, default: false
|
||||
# - when true, a stream of loading messages will be sent to the client in the
|
||||
# reasoning field so chat UIs can show that loading is in progress.
|
||||
# - see #366 for more details
|
||||
sendLoadingState: true
|
||||
|
||||
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
|
||||
# - optional, default: false
|
||||
# - when true, model aliases will be output to the API model listing duplicating
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||
# - useful for reducing common configuration settings
|
||||
# - macro names are strings and must be less than 64 characters
|
||||
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
/path/to/llama-server/llama-server-ec9e0301
|
||||
--port ${PORT}
|
||||
|
||||
"default_ctx": 4096
|
||||
|
||||
# Example of macro-in-macro usage. macros can contain other macros
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
# - model settings have default values that are used if they are not defined here
|
||||
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||
# - below are examples of the all the settings a model can have
|
||||
models:
|
||||
# keys are the model names used in API requests
|
||||
"llama":
|
||||
# macros: a dictionary of string substitutions specific to this model
|
||||
# - optional, default: empty dictionary
|
||||
# - macros defined here override macros defined in the global macros section
|
||||
# - model level macros follow the same rules as global macros
|
||||
macros:
|
||||
"default_ctx": 16384
|
||||
"temp": 0.7
|
||||
|
||||
# cmd: the command to run to start the inference server.
|
||||
# - required
|
||||
# - it is just a string, similar to what you would run on the CLI
|
||||
# - using `|` allows for comments in the command, these will be parsed out
|
||||
# - macros can be used within cmd
|
||||
cmd: |
|
||||
# ${latest-llama} is a macro that is defined above
|
||||
${latest-llama}
|
||||
--model path/to/llama-8B-Q4_K_M.gguf
|
||||
--ctx-size ${default_ctx}
|
||||
--temperature ${temp}
|
||||
|
||||
# name: a display name for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
name: "llama 3.1 8B"
|
||||
|
||||
# description: a description for the model
|
||||
# - optional, default: empty string
|
||||
# - if set, it will be used in the v1/models API response
|
||||
# - if not set, it will be omitted in the JSON model record
|
||||
description: "A small but capable model used for quick testing"
|
||||
|
||||
# env: define an array of environment variables to inject into cmd's environment
|
||||
# - optional, default: empty array
|
||||
# - each value is a single string
|
||||
# - in the format: ENV_NAME=value
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||
|
||||
# proxy: the URL where llama-swap routes API requests
|
||||
# - optional, default: http://localhost:${PORT}
|
||||
# - if you used ${PORT} in cmd this can be omitted
|
||||
# - if you use a custom port in cmd this *must* be set
|
||||
proxy: http://127.0.0.1:8999
|
||||
|
||||
# aliases: alternative model names that this model configuration is used for
|
||||
# - optional, default: empty array
|
||||
# - aliases must be unique globally
|
||||
# - useful for impersonating a specific model
|
||||
aliases:
|
||||
- "gpt-4o-mini"
|
||||
- "gpt-3.5-turbo"
|
||||
|
||||
# checkEndpoint: URL path to check if the server is ready
|
||||
# - optional, default: /health
|
||||
# - endpoint is expected to return an HTTP 200 response
|
||||
# - all requests wait until the endpoint is ready or fails
|
||||
# - use "none" to skip endpoint health checking
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
# useModelName: override the model name that is sent to upstream server
|
||||
# - optional, default: ""
|
||||
# - useful for when the upstream server expects a specific model name that
|
||||
# is different from the model's ID
|
||||
useModelName: "qwen:qwq"
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for server side enforcement of sampling parameters
|
||||
# - the `model` parameter can never be removed
|
||||
# - can be any JSON key in the request body
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
# - metadata is only passed through in /v1/models responses
|
||||
metadata:
|
||||
# port will remain an integer
|
||||
port: ${PORT}
|
||||
|
||||
# the ${temp} macro will remain a float
|
||||
temperature: ${temp}
|
||||
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||
|
||||
a_list:
|
||||
- 1
|
||||
- 1.23
|
||||
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||
|
||||
an_obj:
|
||||
a: "1"
|
||||
b: 2
|
||||
# objects can contain complex types with macro substitution
|
||||
# becomes: c: [0.7, false, "model: llama"]
|
||||
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||
|
||||
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||
# - optional, default: 0
|
||||
# - useful for limiting the number of active parallel requests a model can process
|
||||
# - must be set per model
|
||||
# - any number greater than 0 will override the internal default value of 10
|
||||
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||
# - recommended to be omitted and the default used
|
||||
concurrencyLimit: 0
|
||||
|
||||
# sendLoadingState: overrides the global sendLoadingState setting for this model
|
||||
# - optional, default: undefined (use global setting)
|
||||
sendLoadingState: false
|
||||
|
||||
# Unlisted model example:
|
||||
"qwen-unlisted":
|
||||
# unlisted: boolean, true or false
|
||||
# - optional, default: false
|
||||
# - unlisted models do not show up in /v1/models api requests
|
||||
# - can be requested as normal through all apis
|
||||
unlisted: true
|
||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||
|
||||
# Docker example:
|
||||
# container runtimes like Docker and Podman can be used reliably with
|
||||
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||
"docker-llama":
|
||||
proxy: "http://127.0.0.1:${PORT}"
|
||||
cmd: |
|
||||
docker run --name ${MODEL_ID}
|
||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||
|
||||
# cmdStop: command to run to stop the model gracefully
|
||||
# - optional, default: ""
|
||||
# - useful for stopping commands managed by another system
|
||||
# - the upstream's process id is available in the ${PID} macro
|
||||
#
|
||||
# When empty, llama-swap has this default behaviour:
|
||||
# - on POSIX systems: a SIGTERM signal is sent
|
||||
# - on Windows, calls taskkill to stop the process
|
||||
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||
cmdStop: docker stop ${MODEL_ID}
|
||||
|
||||
# groups: a dictionary of group settings
|
||||
# - optional, default: empty dictionary
|
||||
# - provides advanced controls over model swapping behaviour
|
||||
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||
# - model IDs must be defined in the Models section
|
||||
# - a model can only be a member of one group
|
||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||
# - see issue #109 for details
|
||||
#
|
||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||
groups:
|
||||
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||
# to run a time across the whole llama-swap instance
|
||||
"group1":
|
||||
# swap: controls the model swapping behaviour in within the group
|
||||
# - optional, default: true
|
||||
# - true : only one model is allowed to run at a time
|
||||
# - false: all models can run together, no swapping
|
||||
swap: true
|
||||
|
||||
# exclusive: controls how the group affects other groups
|
||||
# - optional, default: true
|
||||
# - true: causes all other groups to unload when this group runs a model
|
||||
# - false: does not affect other groups
|
||||
exclusive: true
|
||||
|
||||
# members references the models defined above
|
||||
# required
|
||||
members:
|
||||
- "llama"
|
||||
- "qwen-unlisted"
|
||||
|
||||
# Example:
|
||||
# - in group2 all models can run at the same time
|
||||
# - when a different group is loaded it causes all running models in this group to unload
|
||||
"group2":
|
||||
swap: false
|
||||
|
||||
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||
# - the models in group2 will be loaded but will not unload any other groups
|
||||
exclusive: false
|
||||
members:
|
||||
- "docker-llama"
|
||||
- "modelA"
|
||||
- "modelB"
|
||||
|
||||
# Example:
|
||||
# - a persistent group, prevents other groups from unloading it
|
||||
"forever":
|
||||
# persistent: prevents over groups from unloading the models in this group
|
||||
# - optional, default: false
|
||||
# - does not affect individual model behaviour
|
||||
persistent: true
|
||||
|
||||
# set swap/exclusive to false to prevent swapping inside the group
|
||||
# and the unloading of other groups
|
||||
swap: false
|
||||
exclusive: false
|
||||
members:
|
||||
- "forever-modelA"
|
||||
- "forever-modelB"
|
||||
- "forever-modelc"
|
||||
|
||||
# hooks: a dictionary of event triggers and actions
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported hook is on_startup
|
||||
hooks:
|
||||
# on_startup: a dictionary of actions to perform on startup
|
||||
# - optional, default: empty dictionary
|
||||
# - the only supported action is preload
|
||||
on_startup:
|
||||
# preload: a list of model ids to load on startup
|
||||
# - optional, default: empty list
|
||||
# - model names must match keys in the models sections
|
||||
# - when preloading multiple models at once, define a group
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
apiKey: sk-your-openrouter-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
```
|
||||
@@ -0,0 +1,9 @@
|
||||
## Container Security
|
||||
|
||||
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
|
||||
|
||||
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
|
||||
|
||||
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
|
||||
|
||||
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
|
||||
@@ -0,0 +1,153 @@
|
||||
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||
|
||||
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||
|
||||
## Here's what you you need:
|
||||
|
||||
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||
- 24GB VRAM video card
|
||||
|
||||
## Running aider
|
||||
|
||||
The goal is getting this command line to work:
|
||||
|
||||
```sh
|
||||
aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/QwQ \
|
||||
--editor-model openai/qwen-coder-32B \
|
||||
--model-settings-file aider.model.settings.yml \
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||
```
|
||||
|
||||
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||
|
||||
## Create an aider model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
|
||||
#
|
||||
# !!! important: model names must match llama-swap configuration names !!!
|
||||
#
|
||||
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
```
|
||||
|
||||
## llama-swap configuration
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
-ngl 99
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
|
||||
"QwQ":
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||
-ngl 99
|
||||
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
```
|
||||
|
||||
## Advanced, Dual GPU Configuration
|
||||
|
||||
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||
|
||||
In llama-swap's configuration file:
|
||||
|
||||
1. add a `profiles` section with `aider` as the profile name
|
||||
2. using the `env` field to specify the GPU IDs for each model
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
# Add a profile for aider
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
cmd: /path/to/llama-server ...
|
||||
|
||||
"QwQ":
|
||||
# manually set the GPU to run on
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
cmd: /path/to/llama-server ...
|
||||
```
|
||||
|
||||
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||
|
||||
```yaml
|
||||
# aider.model.settings.yml
|
||||
- name: "openai/aider:QwQ"
|
||||
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||
```
|
||||
|
||||
Run aider with:
|
||||
|
||||
```sh
|
||||
$ aider --architect \
|
||||
--no-show-model-warnings \
|
||||
--model openai/aider:QwQ \
|
||||
--editor-model openai/aider:qwen-coder-32B \
|
||||
--config aider.conf.yml \
|
||||
--model-settings-file aider.model.settings.yml
|
||||
--openai-api-key "sk-na" \
|
||||
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||
```
|
||||
@@ -0,0 +1,28 @@
|
||||
# this makes use of llama-swap's profile feature to
|
||||
# keep the architect and editor models in VRAM on different GPUs
|
||||
|
||||
- name: "openai/aider:QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
|
||||
- name: "openai/aider:qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||
@@ -0,0 +1,26 @@
|
||||
- name: "openai/QwQ"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.95
|
||||
top_k: 40
|
||||
presence_penalty: 0.1
|
||||
repetition_penalty: 1
|
||||
num_ctx: 16384
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
weak_model_name: "openai/qwen-coder-32B"
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
- name: "openai/qwen-coder-32B"
|
||||
edit_format: diff
|
||||
extra_params:
|
||||
max_tokens: 16384
|
||||
top_p: 0.8
|
||||
top_k: 20
|
||||
repetition_penalty: 1.05
|
||||
use_temperature: 0.6
|
||||
reasoning_tag: think
|
||||
editor_edit_format: editor-diff
|
||||
editor_model_name: "openai/qwen-coder-32B"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
healthCheckTimeout: 300
|
||||
logLevel: debug
|
||||
|
||||
profiles:
|
||||
aider:
|
||||
- qwen-coder-32B
|
||||
- QwQ
|
||||
|
||||
models:
|
||||
"qwen-coder-32B":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=0"
|
||||
aliases:
|
||||
- coder
|
||||
proxy: "http://127.0.0.1:8999"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||
--ctx-size 16000
|
||||
--ctx-size-draft 16000
|
||||
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||
-ngl 99 -ngld 99
|
||||
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
"QwQ":
|
||||
env:
|
||||
- "CUDA_VISIBLE_DEVICES=1"
|
||||
proxy: "http://127.0.0.1:9503"
|
||||
|
||||
# set appropriate paths for your environment
|
||||
cmd: >
|
||||
/path/to/llama-server
|
||||
--host 127.0.0.1 --port 9503
|
||||
--flash-attn --metrics
|
||||
--slots
|
||||
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||
--cache-type-k q8_0 --cache-type-v q8_0
|
||||
--ctx-size 32000
|
||||
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||
--temp 0.6
|
||||
--repeat-penalty 1.1
|
||||
--dry-multiplier 0.5
|
||||
--min-p 0.01
|
||||
--top-k 40
|
||||
--top-p 0.95
|
||||
-ngl 99 -ngld 99
|
||||
@@ -0,0 +1,51 @@
|
||||
# Restart llama-swap on config change
|
||||
|
||||
Sometimes editing the configuration file can take a bit of trail and error to get a model configuration tuned just right. The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#
|
||||
# A simple watch and restart llama-swap when its configuration
|
||||
# file changes. Useful for trying out configuration changes
|
||||
# without manually restarting the server each time.
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <path to config.yaml>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while true; do
|
||||
# Start the process again
|
||||
./llama-swap-linux-amd64 -config $1 -listen :1867 &
|
||||
PID=$!
|
||||
echo "Started llama-swap with PID $PID"
|
||||
|
||||
# Wait for modifications in the specified directory or file
|
||||
inotifywait -e modify "$1"
|
||||
|
||||
# Check if process exists before sending signal
|
||||
if kill -0 $PID 2>/dev/null; then
|
||||
echo "Sending SIGTERM to $PID"
|
||||
kill -SIGTERM $PID
|
||||
wait $PID
|
||||
else
|
||||
echo "Process $PID no longer exists"
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
```
|
||||
|
||||
## Usage and output example
|
||||
|
||||
```bash
|
||||
$ ./watch-and-restart.sh config.yaml
|
||||
Started llama-swap with PID 495455
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
Sending SIGTERM to 495455
|
||||
Shutting down llama-swap
|
||||
Started llama-swap with PID 495486
|
||||
Setting up watches.
|
||||
Watches established.
|
||||
llama-swap listening on :1867
|
||||
```
|
||||
@@ -0,0 +1,3 @@
|
||||
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||
|
||||
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Default initializes a default in-process dispatcher
|
||||
var Default = NewDispatcherConfig(25000)
|
||||
|
||||
// On subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work. This
|
||||
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||
func On[T Event](handler func(T)) context.CancelFunc {
|
||||
return Subscribe(Default, handler)
|
||||
}
|
||||
|
||||
// OnType subscribes to an event with the specified event type. This functions
|
||||
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||
return SubscribeTo(Default, eventType, handler)
|
||||
}
|
||||
|
||||
// Emit writes an event into the dispatcher. This functions same way as
|
||||
// Publish() but uses the default dispatcher instead.
|
||||
func Emit[T Event](ev T) {
|
||||
Publish(Default, ev)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||
*/
|
||||
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||
unsub()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultPublish(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe
|
||||
var count int64
|
||||
defer On(func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(4)
|
||||
Emit(MyEvent1{})
|
||||
Emit(MyEvent1{})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(4), count)
|
||||
}
|
||||
+324
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Event represents an event contract
|
||||
type Event interface {
|
||||
Type() uint32
|
||||
}
|
||||
|
||||
// registry holds an immutable sorted array of event mappings
|
||||
type registry struct {
|
||||
keys []uint32 // Event types (sorted)
|
||||
grps []any // Corresponding subscribers
|
||||
}
|
||||
|
||||
// ------------------------------------- Dispatcher -------------------------------------
|
||||
|
||||
// Dispatcher represents an event dispatcher.
|
||||
type Dispatcher struct {
|
||||
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||
done chan struct{} // Cancellation
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||
}
|
||||
|
||||
// NewDispatcher creates a new dispatcher of events.
|
||||
func NewDispatcher() *Dispatcher {
|
||||
return NewDispatcherConfig(50000)
|
||||
}
|
||||
|
||||
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||
d := &Dispatcher{
|
||||
done: make(chan struct{}),
|
||||
maxQueue: maxQueue,
|
||||
}
|
||||
|
||||
d.subs.Store(®istry{
|
||||
keys: make([]uint32, 0, 16),
|
||||
grps: make([]any, 0, 16),
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
// Close closes the dispatcher
|
||||
func (d *Dispatcher) Close() error {
|
||||
close(d.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isClosed returns whether the dispatcher is closed or not
|
||||
func (d *Dispatcher) isClosed() bool {
|
||||
select {
|
||||
case <-d.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// findGroup performs a lock-free binary search for the event type
|
||||
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||
reg := d.subs.Load()
|
||||
keys := reg.keys
|
||||
|
||||
// Inlined binary search for better cache locality
|
||||
left, right := 0, len(keys)
|
||||
for left < right {
|
||||
mid := left + (right-left)/2
|
||||
if keys[mid] < eventType {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
if left < len(keys) && keys[left] == eventType {
|
||||
return reg.grps[left]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||
// inferred from the provided type. Must be constant for this to work.
|
||||
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||
var event T
|
||||
return SubscribeTo(broker, event.Type(), handler)
|
||||
}
|
||||
|
||||
// SubscribeTo subscribes to an event with the specified event type.
|
||||
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||
if broker.isClosed() {
|
||||
panic(errClosed)
|
||||
}
|
||||
|
||||
broker.mu.Lock()
|
||||
defer broker.mu.Unlock()
|
||||
|
||||
// Check if group already exists
|
||||
if existing := broker.findGroup(eventType); existing != nil {
|
||||
grp := groupOf[T](eventType, existing)
|
||||
sub := grp.Add(handler)
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||
sub := grp.Add(handler)
|
||||
|
||||
// Copy-on-write: insert new entry in sorted position
|
||||
old := broker.subs.Load()
|
||||
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||
return old.keys[i] >= eventType
|
||||
})
|
||||
|
||||
// Create new arrays with space for one more element
|
||||
newKeys := make([]uint32, len(old.keys)+1)
|
||||
newGrps := make([]any, len(old.grps)+1)
|
||||
|
||||
// Copy elements before insertion point
|
||||
copy(newKeys[:idx], old.keys[:idx])
|
||||
copy(newGrps[:idx], old.grps[:idx])
|
||||
|
||||
// Insert new element
|
||||
newKeys[idx] = eventType
|
||||
newGrps[idx] = grp
|
||||
|
||||
// Copy elements after insertion point
|
||||
copy(newKeys[idx+1:], old.keys[idx:])
|
||||
copy(newGrps[idx+1:], old.grps[idx:])
|
||||
|
||||
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||
broker.subs.Store(newReg)
|
||||
|
||||
return func() {
|
||||
grp.Del(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish writes an event into the dispatcher
|
||||
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||
eventType := ev.Type()
|
||||
if sub := broker.findGroup(eventType); sub != nil {
|
||||
group := groupOf[T](eventType, sub)
|
||||
group.Broadcast(ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Count counts the number of subscribers, this is for testing only.
|
||||
func (d *Dispatcher) count(eventType uint32) int {
|
||||
if group := d.findGroup(eventType); group != nil {
|
||||
return group.(interface{ Count() int }).Count()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// groupOf casts the subscriber group to the specified generic type
|
||||
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||
if group, ok := subs.(*group[T]); ok {
|
||||
return group
|
||||
}
|
||||
|
||||
panic(errConflict[T](eventType, subs))
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber -------------------------------------
|
||||
|
||||
// consumer represents a consumer with a message queue
|
||||
type consumer[T Event] struct {
|
||||
queue []T // Current work queue
|
||||
stop bool // Stop signal
|
||||
}
|
||||
|
||||
// Listen listens to the event queue and processes events
|
||||
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||
pending := make([]T, 0, 128)
|
||||
|
||||
for {
|
||||
c.L.Lock()
|
||||
for len(s.queue) == 0 {
|
||||
switch {
|
||||
case s.stop:
|
||||
c.L.Unlock()
|
||||
return
|
||||
default:
|
||||
c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Swap buffers and reset the current queue
|
||||
temp := s.queue
|
||||
s.queue = pending[:0]
|
||||
pending = temp
|
||||
c.L.Unlock()
|
||||
|
||||
// Outside of the critical section, process the work
|
||||
for _, event := range pending {
|
||||
fn(event)
|
||||
}
|
||||
|
||||
// Notify potential publishers waiting due to backpressure
|
||||
c.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Subscriber Group -------------------------------------
|
||||
|
||||
// group represents a consumer group
|
||||
type group[T Event] struct {
|
||||
cond *sync.Cond
|
||||
subs []*consumer[T]
|
||||
maxQueue int // Maximum queue size per consumer
|
||||
maxLen int // Current maximum queue length across all consumers
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all consumers
|
||||
func (s *group[T]) Broadcast(ev T) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Calculate current maximum queue length
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
|
||||
// Backpressure: wait if queues are full
|
||||
for s.maxLen >= s.maxQueue {
|
||||
s.cond.Wait()
|
||||
|
||||
// Recalculate after wakeup
|
||||
s.maxLen = 0
|
||||
for _, sub := range s.subs {
|
||||
if len(sub.queue) > s.maxLen {
|
||||
s.maxLen = len(sub.queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add event to all queues and track new maximum
|
||||
newMax := 0
|
||||
for _, sub := range s.subs {
|
||||
sub.queue = append(sub.queue, ev)
|
||||
if len(sub.queue) > newMax {
|
||||
newMax = len(sub.queue)
|
||||
}
|
||||
}
|
||||
s.maxLen = newMax
|
||||
s.cond.Broadcast() // Wake consumers
|
||||
}
|
||||
|
||||
// Add adds a subscriber to the list
|
||||
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||
sub := &consumer[T]{
|
||||
queue: make([]T, 0, 64),
|
||||
}
|
||||
|
||||
// Add the consumer to the list of active consumers
|
||||
s.cond.L.Lock()
|
||||
s.subs = append(s.subs, sub)
|
||||
s.cond.L.Unlock()
|
||||
|
||||
// Start listening
|
||||
go sub.Listen(s.cond, handler)
|
||||
return sub
|
||||
}
|
||||
|
||||
// Del removes a subscriber from the list
|
||||
func (s *group[T]) Del(sub *consumer[T]) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
|
||||
// Search and remove the subscriber
|
||||
sub.stop = true
|
||||
for i, v := range s.subs {
|
||||
if v == sub {
|
||||
copy(s.subs[i:], s.subs[i+1:])
|
||||
s.subs = s.subs[:len(s.subs)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------- Debugging -------------------------------------
|
||||
|
||||
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||
|
||||
// Count returns the number of subscribers in this group
|
||||
func (s *group[T]) Count() int {
|
||||
return len(s.subs)
|
||||
}
|
||||
|
||||
// String returns string representation of the type
|
||||
func (s *group[T]) String() string {
|
||||
typ := reflect.TypeOf(s).String()
|
||||
idx := strings.LastIndex(typ, "/")
|
||||
typ = typ[idx+1 : len(typ)-1]
|
||||
return typ
|
||||
}
|
||||
|
||||
// errConflict returns a conflict message
|
||||
func errConflict[T any](eventType uint32, existing any) string {
|
||||
var want T
|
||||
return fmt.Sprintf(
|
||||
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||
want, existing, eventType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Subscribe, must be received in order
|
||||
var count int64
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
// Publish
|
||||
wg.Add(3)
|
||||
Publish(d, MyEvent1{Number: 1})
|
||||
Publish(d, MyEvent1{Number: 2})
|
||||
Publish(d, MyEvent1{Number: 3})
|
||||
|
||||
// Wait and check
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||
// Nothing
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||
unsubscribe()
|
||||
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||
}
|
||||
|
||||
func TestConcurrent(t *testing.T) {
|
||||
const max = 1000000
|
||||
var count int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
d := NewDispatcher()
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
if current := atomic.AddInt64(&count, 1); current == max {
|
||||
wg.Done()
|
||||
}
|
||||
})()
|
||||
|
||||
// Asynchronously publish
|
||||
go func() {
|
||||
for i := 0; i < max; i++ {
|
||||
Publish(d, MyEvent1{})
|
||||
}
|
||||
}()
|
||||
|
||||
defer Subscribe(d, func(ev MyEvent1) {
|
||||
// Subscriber that does nothing
|
||||
})()
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, max, int(count))
|
||||
}
|
||||
|
||||
func TestSubscribeDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublishDifferentType(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
Publish(d, MyEvent1{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseDispatcher(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||
|
||||
assert.NoError(t, d.Close())
|
||||
assert.Panics(t, func() {
|
||||
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix(t *testing.T) {
|
||||
const amount = 1000
|
||||
for _, subs := range []int{1, 10, 100} {
|
||||
for _, topics := range []int{1, 10} {
|
||||
expected := subs * topics * amount
|
||||
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||
var count atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(expected)
|
||||
|
||||
d := NewDispatcher()
|
||||
for i := 0; i < subs; i++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||
count.Add(1)
|
||||
wg.Done()
|
||||
})()
|
||||
}
|
||||
}
|
||||
|
||||
for n := 0; n < amount; n++ {
|
||||
for id := 0; id < topics; id++ {
|
||||
go Publish(d, MyEvent3{ID: id})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, expected, int(count.Load()))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||
// This test specifically targets the race condition that occurs when multiple
|
||||
// goroutines try to subscribe to different event types simultaneously.
|
||||
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||
|
||||
const numGoroutines = 100
|
||||
const numEventTypes = 50
|
||||
|
||||
d := NewDispatcher()
|
||||
defer d.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var receivedCount int64
|
||||
var subscribedTypes sync.Map // Thread-safe map
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Start multiple goroutines that subscribe to different event types concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine subscribes to a unique event type
|
||||
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||
|
||||
// Subscribe to the event type
|
||||
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||
atomic.AddInt64(&receivedCount, 1)
|
||||
})
|
||||
|
||||
// Record that this type was subscribed
|
||||
subscribedTypes.Store(eventType, true)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all subscriptions to complete
|
||||
wg.Wait()
|
||||
|
||||
// Count the number of unique event types subscribed
|
||||
expectedTypes := 0
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
expectedTypes++
|
||||
return true
|
||||
})
|
||||
|
||||
// Small delay to ensure all subscriptions are fully processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Publish events to each subscribed type
|
||||
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||
eventType := key.(uint32)
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
return true
|
||||
})
|
||||
|
||||
// Wait for all events to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify that we received at least the expected number of events
|
||||
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||
received := atomic.LoadInt64(&receivedCount)
|
||||
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||
|
||||
// Verify that we have the expected number of unique event types
|
||||
assert.Equal(t, numEventTypes, expectedTypes,
|
||||
"Should have exactly %d unique event types", numEventTypes)
|
||||
}
|
||||
|
||||
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
|
||||
// Test concurrent subscriptions to the same event type
|
||||
t.Run("SameEventType", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var handlerCount int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||
atomic.AddInt64(&handlerCount, 1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all handlers were registered by publishing an event
|
||||
atomic.StoreInt64(&handlerCount, 0)
|
||||
Publish(d, MyEvent1{})
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||
"Not all handlers were registered due to race condition")
|
||||
})
|
||||
|
||||
// Test concurrent subscriptions to different event types
|
||||
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
var wg sync.WaitGroup
|
||||
receivedEvents := make(map[uint32]*int64)
|
||||
|
||||
// Create multiple event types and subscribe concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
eventType := uint32(100 + i)
|
||||
counter := new(int64)
|
||||
receivedEvents[eventType] = counter
|
||||
|
||||
wg.Add(1)
|
||||
go func(et uint32, cnt *int64) {
|
||||
defer wg.Done()
|
||||
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||
atomic.AddInt64(cnt, 1)
|
||||
})
|
||||
}(eventType, counter)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Publish events to all types
|
||||
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||
Publish(d, MyEvent3{ID: int(eventType)})
|
||||
}
|
||||
|
||||
// Small delay to ensure all handlers have executed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify all event types received their events
|
||||
for eventType, counter := range receivedEvents {
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||
"Event type %d did not receive its event", eventType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackpressure(t *testing.T) {
|
||||
d := NewDispatcher()
|
||||
d.maxQueue = 10
|
||||
|
||||
var processedCount int64
|
||||
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||
atomic.AddInt64(&processedCount, 1)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
const eventsToPublish = 1000
|
||||
for i := 0; i < eventsToPublish; i++ {
|
||||
Publish(d, MyEvent3{ID: 0x200})
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all events were eventually processed
|
||||
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||
}
|
||||
|
||||
// ------------------------------------- Test Events -------------------------------------
|
||||
|
||||
const (
|
||||
TypeEvent1 = 0x1
|
||||
TypeEvent2 = 0x2
|
||||
)
|
||||
|
||||
type MyEvent1 struct {
|
||||
Number int
|
||||
}
|
||||
|
||||
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||
|
||||
type MyEvent2 struct {
|
||||
Text string
|
||||
}
|
||||
|
||||
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||
|
||||
type MyEvent3 struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||
@@ -1,9 +1,14 @@
|
||||
module github.com/mostlygeek/llama-swap
|
||||
|
||||
go 1.23.0
|
||||
go 1.25.4
|
||||
|
||||
require (
|
||||
github.com/billziss-gh/golib v0.2.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -15,12 +20,10 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.10.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
@@ -29,12 +32,14 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
@@ -23,9 +29,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
@@ -57,6 +63,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
@@ -64,24 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
+181
-9
@@ -1,23 +1,38 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/mostlygeek/llama-swap/proxy"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
var version string = "0"
|
||||
var commit string = "abcd1234"
|
||||
var date = "unknown"
|
||||
var (
|
||||
version string = "0"
|
||||
commit string = "abcd1234"
|
||||
date string = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Define a command-line flag for the port
|
||||
configPath := flag.String("config", "config.yaml", "config file name")
|
||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||
listenStr := flag.String("listen", "", "listen ip/port")
|
||||
certFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
keyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
showVersion := flag.Bool("version", false, "show version of build")
|
||||
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
@@ -26,22 +41,179 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
config, err := proxy.LoadConfig(*configPath)
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(conf.Profiles) > 0 {
|
||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||
}
|
||||
|
||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||
gin.SetMode(mode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
proxyManager := proxy.New(config)
|
||||
fmt.Println("llama-swap listening on " + *listenStr)
|
||||
if err := proxyManager.Run(*listenStr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
// Validate TLS flags.
|
||||
var useTLS = (*certFile != "" && *keyFile != "")
|
||||
if (*certFile != "" && *keyFile == "") ||
|
||||
(*certFile == "" && *keyFile != "") {
|
||||
fmt.Println("Error: Both --tls-cert-file and --tls-key-file must be provided for TLS.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default ports.
|
||||
if *listenStr == "" {
|
||||
defaultPort := ":8080"
|
||||
if useTLS {
|
||||
defaultPort = ":8443"
|
||||
}
|
||||
listenStr = &defaultPort
|
||||
}
|
||||
|
||||
// Setup channels for server management
|
||||
exitChan := make(chan struct{})
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create server with initial handler
|
||||
srv := &http.Server{
|
||||
Addr: *listenStr,
|
||||
}
|
||||
|
||||
// Support for watching config and reloading when it changes
|
||||
reloadProxyManager := func() {
|
||||
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Configuration Changed")
|
||||
currentPM.Shutdown()
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
fmt.Println("Configuration Reloaded")
|
||||
|
||||
// wait a few seconds and tell any UI to reload
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateEnd,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
conf, err = config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
newPM := proxy.New(conf)
|
||||
newPM.SetVersion(date, commit, version)
|
||||
srv.Handler = newPM
|
||||
}
|
||||
}
|
||||
|
||||
// load the initial proxy manager
|
||||
reloadProxyManager()
|
||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||
if *watchConfig {
|
||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||
debouncedReload()
|
||||
}
|
||||
})()
|
||||
|
||||
fmt.Println("Watching Configuration for changes")
|
||||
go func() {
|
||||
absConfigPath, err := filepath.Abs(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||
return
|
||||
}
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
configDir := filepath.Dir(absConfigPath)
|
||||
err = watcher.Add(configDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||
return
|
||||
}
|
||||
|
||||
defer watcher.Close()
|
||||
for {
|
||||
select {
|
||||
case changeEvent := <-watcher.Events:
|
||||
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||
// the change for k8s configmap
|
||||
event.Emit(proxy.ConfigFileChangedEvent{
|
||||
ReloadingState: proxy.ReloadingStateStart,
|
||||
})
|
||||
}
|
||||
|
||||
case err := <-watcher.Errors:
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown on signal
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||
pm.Shutdown()
|
||||
} else {
|
||||
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||
}
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
fmt.Printf("Server shutdown error: %v\n", err)
|
||||
}
|
||||
close(exitChan)
|
||||
}()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
var err error
|
||||
if useTLS {
|
||||
fmt.Printf("llama-swap listening with TLS on https://%s\n", *listenStr)
|
||||
err = srv.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
fmt.Printf("llama-swap listening on http://%s\n", *listenStr)
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Fatal server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for exit signal
|
||||
<-exitChan
|
||||
}
|
||||
|
||||
func debounce(interval time.Duration, f func()) func() {
|
||||
var timer *time.Timer
|
||||
return func() {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.AfterFunc(interval, f)
|
||||
}
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -1,139 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
// Define a command-line flag for the port
|
||||
port := flag.String("port", "8080", "port to listen on")
|
||||
|
||||
// Define a command-line flag for the response message
|
||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||
|
||||
silent := flag.Bool("silent", false, "disable all logging")
|
||||
|
||||
flag.Parse() // Parse the command-line flags
|
||||
|
||||
// Create a new Gin router
|
||||
r := gin.New()
|
||||
|
||||
// Set up the handler function using the provided response message
|
||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
|
||||
// add a wait to simulate a slow query
|
||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.POST("/v1/completions", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/slow-respond", func(c *gin.Context) {
|
||||
echo := c.Query("echo")
|
||||
delay := c.Query("delay")
|
||||
|
||||
if echo == "" {
|
||||
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
if delay == "" {
|
||||
delay = "100ms"
|
||||
}
|
||||
|
||||
t, err := time.ParseDuration(delay)
|
||||
if err != nil {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/plain")
|
||||
for _, char := range echo {
|
||||
c.Writer.Write([]byte(string(char)))
|
||||
c.Writer.Flush()
|
||||
|
||||
// wait
|
||||
<-time.After(t)
|
||||
}
|
||||
})
|
||||
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
})
|
||||
|
||||
r.GET("/env", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, *responseMessage)
|
||||
|
||||
// Get environment variables
|
||||
envVars := os.Environ()
|
||||
|
||||
// Write each environment variable to the response
|
||||
for _, envVar := range envVars {
|
||||
c.String(200, envVar)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up the /health endpoint handler function
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||
})
|
||||
|
||||
address := "127.0.0.1:" + *port // Address with the specified port
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
Handler: r.Handler(),
|
||||
}
|
||||
|
||||
// Disable logging if the --silent flag is set
|
||||
if *silent {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("simple-responder listening on %s\n", address)
|
||||
// service connections
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("simple-responder err: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal to gracefully shutdown the server with
|
||||
// a timeout of 5 seconds.
|
||||
quit := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("simple-responder shutting down")
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
ui_dist/*
|
||||
@@ -1,98 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/google/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
Models map[string]ModelConfig `yaml:"models"`
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
// Remove trailing backslashes
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
||||
|
||||
// Split the command into arguments
|
||||
args, err := shlex.Split(cmdStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
@@ -0,0 +1,780 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
const (
|
||||
LogToStdoutProxy = "proxy"
|
||||
LogToStdoutUpstream = "upstream"
|
||||
LogToStdoutBoth = "both"
|
||||
LogToStdoutNone = "none"
|
||||
)
|
||||
|
||||
type MacroEntry struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
type MacroList []MacroEntry
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||
if value.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("macros must be a mapping")
|
||||
}
|
||||
|
||||
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||
for i := 0; i < len(value.Content); i += 2 {
|
||||
keyNode := value.Content[i]
|
||||
valueNode := value.Content[i+1]
|
||||
|
||||
var name string
|
||||
if err := keyNode.Decode(&name); err != nil {
|
||||
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||
}
|
||||
|
||||
var val any
|
||||
if err := valueNode.Decode(&val); err != nil {
|
||||
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||
}
|
||||
|
||||
*ml = entries
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a macro value by name
|
||||
func (ml MacroList) Get(name string) (any, bool) {
|
||||
for _, entry := range ml {
|
||||
if entry.Name == name {
|
||||
return entry.Value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||
func (ml MacroList) ToMap() map[string]any {
|
||||
result := make(map[string]any, len(ml))
|
||||
for _, entry := range ml {
|
||||
result[entry.Name] = entry.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
Exclusive bool `yaml:"exclusive"`
|
||||
Persistent bool `yaml:"persistent"`
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
defaults := rawGroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Persistent: false,
|
||||
Members: []string{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GroupConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||
}
|
||||
|
||||
type HookOnStartup struct {
|
||||
Preload []string `yaml:"preload"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
LogTimeFormat string `yaml:"logTimeFormat"`
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
|
||||
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// map aliases to actual model IDs
|
||||
aliases map[string]string
|
||||
|
||||
// automatic port assignments
|
||||
StartPort int `yaml:"startPort"`
|
||||
|
||||
// hooks, see: #209
|
||||
Hooks HooksConfig `yaml:"hooks"`
|
||||
|
||||
// send loading state in reasoning
|
||||
SendLoadingState bool `yaml:"sendLoadingState"`
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
if _, found := c.Models[search]; found {
|
||||
return search, true
|
||||
} else if name, found := c.aliases[search]; found {
|
||||
return name, found
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
if realName, found := c.RealModelName(modelName); !found {
|
||||
return ModelConfig{}, "", false
|
||||
} else {
|
||||
return c.Models[realName], realName, true
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
if config.Groups == nil {
|
||||
config.Groups = make(map[string]GroupConfig)
|
||||
}
|
||||
|
||||
defaultGroup := GroupConfig{
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{},
|
||||
}
|
||||
// if groups is empty, create a default group and put
|
||||
// all models into it
|
||||
if len(config.Groups) == 0 {
|
||||
for modelName := range config.Models {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
for _, groupConfig := range config.Groups {
|
||||
for _, member := range groupConfig.Members {
|
||||
if member == modelName {
|
||||
foundModel = true
|
||||
break found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundModel {
|
||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
}
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
//go:build !windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
# comment 1
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
// Test the default values are automatically set for global, model and group configurations
|
||||
// after loading the configuration
|
||||
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadPosix(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
hooks:
|
||||
on_startup:
|
||||
preload: ["model1", "model2"]
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
name: "Model 1"
|
||||
description: "This is model 1"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
OnStartup: HookOnStartup{
|
||||
Preload: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
Name: "Model 1",
|
||||
Description: "This is model 1",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,242 @@
|
||||
//go:build windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
// does not support single quoted strings like in config_posix_test.go
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
|
||||
-a "double quotes" \
|
||||
-s
|
||||
--arg3 123 \
|
||||
|
||||
# comment 2
|
||||
--arg4 '"string in string"'
|
||||
|
||||
|
||||
|
||||
# this will get stripped out as well as the white space above
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
|
||||
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||
assert.Equal(t, 5800, config.StartPort)
|
||||
assert.Equal(t, "info", config.LogLevel)
|
||||
assert.Equal(t, "", config.LogTimeFormat)
|
||||
|
||||
// Test default group exists
|
||||
defaultGroup, exists := config.Groups["(default)"]
|
||||
assert.True(t, exists, "default group should exist")
|
||||
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||
assert.Equal(t, true, defaultGroup.Swap)
|
||||
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||
assert.Equal(t, false, defaultGroup.Persistent)
|
||||
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||
}
|
||||
|
||||
model1, exists := config.Models["model1"]
|
||||
assert.True(t, exists, "model1 should exist")
|
||||
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||
assert.Equal(t, []string{}, model1.Aliases)
|
||||
assert.Equal(t, []string{}, model1.Env)
|
||||
assert.Equal(t, 0, model1.UnloadAfter)
|
||||
assert.Equal(t, false, model1.Unlisted)
|
||||
assert.Equal(t, "", model1.UseModelName)
|
||||
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||
}
|
||||
|
||||
// default empty filter exists
|
||||
assert.Equal(t, "", model1.Filters.StripParams)
|
||||
}
|
||||
|
||||
func TestConfig_LoadWindows(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
macros:
|
||||
svr-path: "path/to/server"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: ${svr-path} --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
model3:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "mthree"
|
||||
checkEndpoint: "/"
|
||||
model4:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8082"
|
||||
checkEndpoint: "/"
|
||||
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
groups:
|
||||
group1:
|
||||
swap: true
|
||||
exclusive: false
|
||||
members: ["model2"]
|
||||
forever:
|
||||
exclusive: false
|
||||
persistent: true
|
||||
members:
|
||||
- "model4"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
modelLoadingState := false
|
||||
|
||||
expected := Config{
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
StartPort: 5800,
|
||||
Macros: MacroList{
|
||||
{"svr-path", "path/to/server"},
|
||||
},
|
||||
SendLoadingState: false,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/server --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model3": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"mthree"},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/",
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
"model4": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||
Proxy: "http://localhost:8082",
|
||||
CheckEndpoint: "/",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
SendLoadingState: &modelLoadingState,
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
DEFAULT_GROUP_ID: {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model3"},
|
||||
},
|
||||
"group1": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Members: []string{"model2"},
|
||||
},
|
||||
"forever": {
|
||||
Swap: true,
|
||||
Exclusive: false,
|
||||
Persistent: true,
|
||||
Members: []string{"model4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
|
||||
// SetParamsByID maps requested model IDs to parameters to set/override in requests.
|
||||
// Useful with aliases: a single loaded model can behave differently depending on
|
||||
// which alias the client used. Applied after SetParams, so it can override those values.
|
||||
// Protected params (like "model") cannot be set.
|
||||
SetParamsByID map[string]map[string]any `yaml:"setParamsByID"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParamsByID returns the params to set for the given requestedModelID,
|
||||
// with protected params removed and keys sorted for consistent iteration order.
|
||||
// Returns nil if the ID has no entry or all its params are protected.
|
||||
func (f Filters) SanitizedSetParamsByID(requestedModelID string) (map[string]any, []string) {
|
||||
if len(f.SetParamsByID) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
params, found := f.SetParamsByID[requestedModelID]
|
||||
if !found || len(params) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
result := make(map[string]any, len(params))
|
||||
keys := make([]string, 0, len(params))
|
||||
for key, value := range params {
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return result, keys
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParamsByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParamsByID map[string]map[string]any
|
||||
requestedModelID string
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty SetParamsByID returns nil",
|
||||
setParamsByID: nil,
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map returns nil",
|
||||
setParamsByID: map[string]map[string]any{},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "non-matching model ID returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model2": {"temperature": 0.9},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "matching model ID returns correct params",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {"temperature": 0.7, "top_p": 0.9},
|
||||
"model2": {"temperature": 0.5},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected param model is filtered out",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param returns nil",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "keys are sorted",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1": {
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
},
|
||||
requestedModelID: "model1",
|
||||
wantParams: map[string]any{
|
||||
"z_param": "z",
|
||||
"a_param": "a",
|
||||
"m_param": "m",
|
||||
},
|
||||
wantKeys: []string{"a_param", "m_param", "z_param"},
|
||||
},
|
||||
{
|
||||
name: "alias style key lookup",
|
||||
setParamsByID: map[string]map[string]any{
|
||||
"model1:high": {"reasoning_effort": "high"},
|
||||
"model1:low": {"reasoning_effort": "low"},
|
||||
},
|
||||
requestedModelID: "model1:high",
|
||||
wantParams: map[string]any{
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
wantKeys: []string{"reasoning_effort"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParamsByID: tt.setParamsByID}
|
||||
gotParams, gotKeys := f.SanitizedSetParamsByID(tt.requestedModelID)
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams)
|
||||
assert.Nil(t, gotKeys)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.wantKeys, gotKeys)
|
||||
assert.Equal(t, tt.wantParams, gotParams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test macro-in-macro basic substitution
|
||||
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-A"
|
||||
"B": "prefix-${A}-suffix"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${B}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test LIFO substitution order with 3+ macro levels
|
||||
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"base": "/models"
|
||||
"path": "${base}/llama"
|
||||
"full": "${path}/model.gguf"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: load ${full}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test MODEL_ID in global macro used by model
|
||||
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||
|
||||
models:
|
||||
my-model:
|
||||
cmd: ${podman-llama} -m model.gguf
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||
}
|
||||
|
||||
// Test model macro overrides global macro in substitution
|
||||
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"tag": "global"
|
||||
"msg": "value-${tag}"
|
||||
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
"tag": "model-level"
|
||||
cmd: echo ${msg}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||
}
|
||||
|
||||
// Test self-reference detection error
|
||||
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"recursive": "value-${recursive}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${recursive}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "recursive")
|
||||
assert.Contains(t, err.Error(), "self-reference")
|
||||
}
|
||||
|
||||
// Test undefined macro reference error
|
||||
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||
content := `
|
||||
startPort: 10000
|
||||
macros:
|
||||
"A": "value-${UNDEFINED}"
|
||||
|
||||
models:
|
||||
test:
|
||||
cmd: echo ${A}
|
||||
proxy: http://localhost:8080
|
||||
`
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||
UnloadAfter int `yaml:"ttl"`
|
||||
Unlisted bool `yaml:"unlisted"`
|
||||
UseModelName string `yaml:"useModelName"`
|
||||
|
||||
// #179 for /v1/models
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
// Limit concurrency of HTTP requests to process
|
||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||
|
||||
// Model filters see issue #174
|
||||
Filters ModelFilters `yaml:"filters"`
|
||||
|
||||
// Macros: see #264
|
||||
// Model level macros take precedence over the global macros
|
||||
Macros MacroList `yaml:"macros"`
|
||||
|
||||
// Metadata: see #264
|
||||
// Arbitrary metadata that can be exposed through the API
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
|
||||
// override global setting
|
||||
SendLoadingState *bool `yaml:"sendLoadingState"`
|
||||
}
|
||||
|
||||
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelConfig ModelConfig
|
||||
defaults := rawModelConfig{
|
||||
Cmd: "",
|
||||
CmdStop: "",
|
||||
Proxy: "http://localhost:${PORT}",
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
Name: "",
|
||||
Description: "",
|
||||
}
|
||||
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" {
|
||||
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = ModelConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal with the old field name for backwards compatibility
|
||||
if defaults.StripParams == "" {
|
||||
var legacy struct {
|
||||
StripParams string `yaml:"strip_params"`
|
||||
}
|
||||
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||
}
|
||||
defaults.StripParams = legacy.StripParams
|
||||
}
|
||||
|
||||
*m = ModelFilters(defaults)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_ModelFilters(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
default_strip: "temperature, top_p"
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
# macros inserted and list is cleaned of duplicates and empty strings
|
||||
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
# check for strip_params (legacy field name) compatibility
|
||||
legacy:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
for modelId, modelConfig := range config.Models {
|
||||
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||
if assert.NoError(t, err) {
|
||||
// model has been removed
|
||||
// empty strings have been removed
|
||||
// duplicates have been removed
|
||||
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelSendLoadingState(t *testing.T) {
|
||||
content := `
|
||||
sendLoadingState: true
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
sendLoadingState: false
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, config.SendLoadingState)
|
||||
if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
|
||||
assert.False(t, *config.Models["model1"].SendLoadingState)
|
||||
}
|
||||
if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAlias(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"${MODEL_ID}:high":
|
||||
reasoning_effort: high
|
||||
"${MODEL_ID}:low":
|
||||
reasoning_effort: low
|
||||
`
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Keys (other than the model's own ID) should be registered as aliases
|
||||
realName, found := cfg.RealModelName("model1:high")
|
||||
assert.True(t, found, "model1:high should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
realName, found = cfg.RealModelName("model1:low")
|
||||
assert.True(t, found, "model1:low should be an auto-registered alias")
|
||||
assert.Equal(t, "model1", realName)
|
||||
|
||||
// Auto-aliases should also appear in modelConfig.Aliases
|
||||
aliases := cfg.Models["model1"].Aliases
|
||||
assert.Contains(t, aliases, "model1:high")
|
||||
assert.Contains(t, aliases, "model1:low")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithModelID(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
model2:
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "conflicts with an existing model ID")
|
||||
}
|
||||
|
||||
func TestConfig_SetParamsByIDAutoAliasConflictWithOtherModel(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: high
|
||||
model2:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
setParamsByID:
|
||||
"shared-alias":
|
||||
reasoning_effort: low
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.ErrorContains(t, err, "duplicate alias")
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
// Create a temporary YAML file for testing
|
||||
tempDir, err := os.MkdirTemp("", "test-config")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
aliases:
|
||||
- "m1"
|
||||
- "model-one"
|
||||
env:
|
||||
- "VAR1=value1"
|
||||
- "VAR2=value2"
|
||||
checkEndpoint: "/health"
|
||||
model2:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8081"
|
||||
aliases:
|
||||
- "m2"
|
||||
checkEndpoint: "/"
|
||||
healthCheckTimeout: 15
|
||||
profiles:
|
||||
test:
|
||||
- model1
|
||||
- model2
|
||||
`
|
||||
|
||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write temporary file: %v", err)
|
||||
}
|
||||
|
||||
// Load the config and verify
|
||||
config, err := LoadConfig(tempFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "path/to/cmd --arg1 one",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2"},
|
||||
Env: nil,
|
||||
CheckEndpoint: "/",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, config)
|
||||
|
||||
realname, found := config.RealModelName("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", realname)
|
||||
}
|
||||
|
||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||
config := &ModelConfig{
|
||||
Cmd: `python model1.py \
|
||||
--arg1 value1 \
|
||||
--arg2 value2`,
|
||||
}
|
||||
|
||||
args, err := config.SanitizedCommand()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||
}
|
||||
|
||||
func TestConfig_FindConfig(t *testing.T) {
|
||||
|
||||
// TODO?
|
||||
// make make this shared between the different tests
|
||||
config := &Config{
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": {
|
||||
Cmd: "python model1.py",
|
||||
Proxy: "http://localhost:8080",
|
||||
Aliases: []string{"m1", "model-one"},
|
||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||
CheckEndpoint: "/health",
|
||||
},
|
||||
"model2": {
|
||||
Cmd: "python model2.py",
|
||||
Proxy: "http://localhost:8081",
|
||||
Aliases: []string{"m2", "model-two"},
|
||||
Env: []string{"VAR3=value3", "VAR4=value4"},
|
||||
CheckEndpoint: "/status",
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 10,
|
||||
aliases: map[string]string{
|
||||
"m1": "model1",
|
||||
"model-one": "model1",
|
||||
"m2": "model2",
|
||||
},
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, modelId, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, modelId, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, modelId, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, "", modelId)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
|
||||
// Test a command with spaces and newlines
|
||||
args, err := SanitizeCommand(`python model1.py \
|
||||
-a "double quotes" \
|
||||
--arg2 'single quotes'
|
||||
-s
|
||||
--arg3 123 \
|
||||
--arg4 '"string in string"'
|
||||
-c "'single quoted'"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"python", "model1.py",
|
||||
"-a", "double quotes",
|
||||
"--arg2", "single quotes",
|
||||
"-s",
|
||||
"--arg3", "123",
|
||||
"--arg4", `"string in string"`,
|
||||
"-c", `'single quoted'`,
|
||||
}, args)
|
||||
|
||||
// Test an empty command
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||
type DiscardWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *DiscardWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
}
|
||||
|
||||
// Satisfy the http.Flusher interface for streaming responses
|
||||
func (w *DiscardWriter) Flush() {}
|
||||
@@ -0,0 +1,69 @@
|
||||
package proxy
|
||||
|
||||
// package level registry of the different event types
|
||||
|
||||
const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
const ModelPreloadedEventID = 0x06
|
||||
const InFlightRequestsEventID = 0x07
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
NewState ProcessState
|
||||
OldState ProcessState
|
||||
}
|
||||
|
||||
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||
return ProcessStateChangeEventID
|
||||
}
|
||||
|
||||
type ChatCompletionStats struct {
|
||||
TokensGenerated int
|
||||
}
|
||||
|
||||
func (e ChatCompletionStats) Type() uint32 {
|
||||
return ChatCompletionStatsEventID
|
||||
}
|
||||
|
||||
type ReloadingState int
|
||||
|
||||
const (
|
||||
ReloadingStateStart ReloadingState = iota
|
||||
ReloadingStateEnd
|
||||
)
|
||||
|
||||
type ConfigFileChangedEvent struct {
|
||||
ReloadingState ReloadingState
|
||||
}
|
||||
|
||||
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||
return ConfigFileChangedEventID
|
||||
}
|
||||
|
||||
type LogDataEvent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (e LogDataEvent) Type() uint32 {
|
||||
return LogDataEventID
|
||||
}
|
||||
|
||||
type ModelPreloadedEvent struct {
|
||||
ModelName string
|
||||
Success bool
|
||||
}
|
||||
|
||||
func (e ModelPreloadedEvent) Type() uint32 {
|
||||
return ModelPreloadedEventID
|
||||
}
|
||||
|
||||
type InFlightRequestsEvent struct {
|
||||
Total int
|
||||
}
|
||||
|
||||
func (e InFlightRequestsEvent) Type() uint32 {
|
||||
return InFlightRequestsEventID
|
||||
}
|
||||
+44
-12
@@ -9,11 +9,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
nextTestPort int = 12000
|
||||
portMutex sync.Mutex
|
||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||
simpleResponderPath = getSimpleResponderPath()
|
||||
)
|
||||
|
||||
// Check if the binary exists
|
||||
@@ -26,6 +30,17 @@ func TestMain(m *testing.M) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "debug":
|
||||
testLogger.SetLogLevel(LevelDebug)
|
||||
case "warn":
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
case "info":
|
||||
testLogger.SetLogLevel(LevelInfo)
|
||||
default:
|
||||
testLogger.SetLogLevel(LevelWarn)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
@@ -33,26 +48,43 @@ func TestMain(m *testing.M) {
|
||||
func getSimpleResponderPath() string {
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
|
||||
if goos == "windows" {
|
||||
return filepath.Join("..", "build", "simple-responder.exe")
|
||||
} else {
|
||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||
}
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||
func getTestPort() int {
|
||||
portMutex.Lock()
|
||||
defer portMutex.Unlock()
|
||||
|
||||
port := nextTestPort
|
||||
nextTestPort++
|
||||
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||
return port
|
||||
}
|
||||
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||
binaryPath := getSimpleResponderPath()
|
||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||
}
|
||||
|
||||
// Create a process configuration
|
||||
return ModelConfig{
|
||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||
// Convert path to forward slashes for cross-platform compatibility
|
||||
// Windows handles forward slashes in paths correctly
|
||||
cmdPath := filepath.ToSlash(simpleResponderPath)
|
||||
|
||||
// Create a YAML string with just the values we want to set
|
||||
yamlStr := fmt.Sprintf(`
|
||||
cmd: '%s --port %d --silent --respond %s'
|
||||
proxy: "http://127.0.0.1:%d"
|
||||
`, cmdPath, port, expectedMessage, port)
|
||||
|
||||
var cfg config.ModelConfig
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB |
@@ -1,14 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>llama-swap</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>llama-swap</h1>
|
||||
<p>
|
||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,145 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Logs</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: "Courier New", Courier, monospace;
|
||||
}
|
||||
#log-controls {
|
||||
margin: 0.5em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between; /* Spaces out elements evenly */
|
||||
}
|
||||
#log-controls input {
|
||||
flex: 1;
|
||||
}
|
||||
#log-controls input:focus {
|
||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
||||
}
|
||||
#log-stream {
|
||||
flex: 1;
|
||||
margin: 0.5em;
|
||||
padding: 1em;
|
||||
background: #f4f4f4;
|
||||
overflow-y: auto;
|
||||
white-space: pre-wrap; /* Ensures line wrapping */
|
||||
word-wrap: break-word; /* Ensures long words wrap */
|
||||
}
|
||||
|
||||
.regex-error {
|
||||
background-color: #ff0000 !important;
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
background-color: #333;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-stream {
|
||||
background: #444;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
#log-controls input {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
|
||||
#log-controls button {
|
||||
background: #555;
|
||||
color: #fff;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<pre id="log-stream">Waiting for logs...</pre>
|
||||
<div id="log-controls">
|
||||
<input type="text" id="filter-input" placeholder="regex filter">
|
||||
<button id="clear-button">clear</button>
|
||||
</div>
|
||||
<script>
|
||||
const logStream = document.getElementById('log-stream');
|
||||
const filterInput = document.getElementById('filter-input');
|
||||
var logData = "";
|
||||
let regexFilter = null;
|
||||
|
||||
function setupEventSource() {
|
||||
if (typeof(EventSource) !== "undefined") {
|
||||
const eventSource = new EventSource("/logs/streamSSE");
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
logData += event.data;
|
||||
render()
|
||||
};
|
||||
|
||||
eventSource.onerror = function(err) {
|
||||
logData = "EventSource failed: " + err.message;
|
||||
};
|
||||
} else {
|
||||
logData = "SSE Not supported by this browser."
|
||||
}
|
||||
}
|
||||
|
||||
// poor-ai's react ¯\_(ツ)_/¯
|
||||
function render() {
|
||||
if (regexFilter) {
|
||||
const lines = logData.split('\n');
|
||||
const filteredLines = lines.filter(line => {
|
||||
return regexFilter === null || regexFilter.test(line);
|
||||
});
|
||||
|
||||
if (filteredLines.length > 0) {
|
||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
||||
} else {
|
||||
logStream.textContent = "";
|
||||
}
|
||||
} else {
|
||||
logStream.textContent = logData;
|
||||
}
|
||||
|
||||
logStream.scrollTop = logStream.scrollHeight;
|
||||
}
|
||||
|
||||
function updateFilter() {
|
||||
const pattern = filterInput.value.trim();
|
||||
filterInput.classList.remove('regex-error');
|
||||
if (pattern) {
|
||||
try {
|
||||
regexFilter = new RegExp(pattern);
|
||||
} catch (e) {
|
||||
console.error("Invalid regex pattern:", e);
|
||||
regexFilter = null;
|
||||
filterInput.classList.add('regex-error');
|
||||
return
|
||||
}
|
||||
} else {
|
||||
regexFilter = null;
|
||||
}
|
||||
|
||||
render();
|
||||
}
|
||||
|
||||
filterInput.addEventListener('input', updateFilter);
|
||||
document.getElementById('clear-button').addEventListener('click', () => {
|
||||
filterInput.value = "";
|
||||
regexFilter = null;
|
||||
render();
|
||||
});
|
||||
setupEventSource();
|
||||
updateFilter();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,10 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed html
|
||||
var htmlFiles embed.FS
|
||||
|
||||
func getHTMLFile(path string) ([]byte, error) {
|
||||
return htmlFiles.ReadFile("html/" + path)
|
||||
}
|
||||
+216
-43
@@ -1,20 +1,121 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// circularBuffer is a fixed-size circular byte buffer that overwrites
|
||||
// oldest data when full. It provides O(1) writes and O(n) reads.
|
||||
type circularBuffer struct {
|
||||
data []byte // pre-allocated capacity
|
||||
head int // next write position
|
||||
size int // current number of bytes stored (0 to cap)
|
||||
}
|
||||
|
||||
func newCircularBuffer(capacity int) *circularBuffer {
|
||||
return &circularBuffer{
|
||||
data: make([]byte, capacity),
|
||||
head: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends bytes to the buffer, overwriting oldest data when full.
|
||||
// Data is copied into the internal buffer (not stored by reference).
|
||||
func (cb *circularBuffer) Write(p []byte) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cap := len(cb.data)
|
||||
|
||||
// If input is larger than capacity, only keep the last cap bytes
|
||||
if len(p) >= cap {
|
||||
copy(cb.data, p[len(p)-cap:])
|
||||
cb.head = 0
|
||||
cb.size = cap
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate how much space is available from head to end of buffer
|
||||
firstPart := cap - cb.head
|
||||
if firstPart >= len(p) {
|
||||
// All data fits without wrapping
|
||||
copy(cb.data[cb.head:], p)
|
||||
cb.head = (cb.head + len(p)) % cap
|
||||
} else {
|
||||
// Data wraps around
|
||||
copy(cb.data[cb.head:], p[:firstPart])
|
||||
copy(cb.data[:len(p)-firstPart], p[firstPart:])
|
||||
cb.head = len(p) - firstPart
|
||||
}
|
||||
|
||||
// Update size
|
||||
cb.size += len(p)
|
||||
if cb.size > cap {
|
||||
cb.size = cap
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all buffered data in correct order (oldest to newest).
|
||||
// Returns a new slice (copy), not a view into internal buffer.
|
||||
func (cb *circularBuffer) GetHistory() []byte {
|
||||
if cb.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, cb.size)
|
||||
cap := len(cb.data)
|
||||
|
||||
// Calculate start position (oldest data)
|
||||
start := (cb.head - cb.size + cap) % cap
|
||||
|
||||
if start+cb.size <= cap {
|
||||
// Data is contiguous, single copy
|
||||
copy(result, cb.data[start:start+cb.size])
|
||||
} else {
|
||||
// Data wraps around, two copies
|
||||
firstPart := cap - start
|
||||
copy(result[:firstPart], cb.data[start:])
|
||||
copy(result[firstPart:], cb.data[:cb.size-firstPart])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
LevelDebug LogLevel = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
|
||||
LogBufferSize = 100 * 1024
|
||||
)
|
||||
|
||||
type LogMonitor struct {
|
||||
clients map[chan []byte]bool
|
||||
eventbus *event.Dispatcher
|
||||
mu sync.RWMutex
|
||||
buffer *ring.Ring
|
||||
buffer *circularBuffer
|
||||
bufferMu sync.RWMutex
|
||||
|
||||
// typically this can be os.Stdout
|
||||
stdout io.Writer
|
||||
|
||||
// logging levels
|
||||
level LogLevel
|
||||
prefix string
|
||||
|
||||
// timestamps
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
func NewLogMonitor() *LogMonitor {
|
||||
@@ -23,9 +124,12 @@ func NewLogMonitor() *LogMonitor {
|
||||
|
||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||
return &LogMonitor{
|
||||
clients: make(map[chan []byte]bool),
|
||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||
stdout: stdout,
|
||||
eventbus: event.NewDispatcherConfig(1000),
|
||||
buffer: nil, // lazy initialized on first Write
|
||||
stdout: stdout,
|
||||
level: LevelInfo,
|
||||
prefix: "",
|
||||
timeFormat: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
w.bufferMu.Lock()
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.buffer.Value = bufferCopy
|
||||
w.buffer = w.buffer.Next()
|
||||
if w.buffer == nil {
|
||||
w.buffer = newCircularBuffer(LogBufferSize)
|
||||
}
|
||||
w.buffer.Write(p)
|
||||
w.bufferMu.Unlock()
|
||||
|
||||
// Make a copy for broadcast to preserve immutability
|
||||
bufferCopy := make([]byte, len(p))
|
||||
copy(bufferCopy, p)
|
||||
w.broadcast(bufferCopy)
|
||||
return n, nil
|
||||
}
|
||||
@@ -53,44 +160,110 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
|
||||
func (w *LogMonitor) GetHistory() []byte {
|
||||
w.bufferMu.RLock()
|
||||
defer w.bufferMu.RUnlock()
|
||||
if w.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return w.buffer.GetHistory()
|
||||
}
|
||||
|
||||
var history []byte
|
||||
w.buffer.Do(func(p any) {
|
||||
if p != nil {
|
||||
if content, ok := p.([]byte); ok {
|
||||
history = append(history, content...)
|
||||
}
|
||||
}
|
||||
// Clear releases the buffer memory, making it eligible for GC.
|
||||
// The buffer will be lazily re-allocated on the next Write.
|
||||
func (w *LogMonitor) Clear() {
|
||||
w.bufferMu.Lock()
|
||||
w.buffer = nil
|
||||
w.bufferMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||
callback(e.Data)
|
||||
})
|
||||
return history
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Subscribe() chan []byte {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
ch := make(chan []byte, 100)
|
||||
w.clients[ch] = true
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
delete(w.clients, ch)
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) broadcast(msg []byte) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||
}
|
||||
|
||||
for client := range w.clients {
|
||||
select {
|
||||
case client <- msg:
|
||||
default:
|
||||
// If client buffer is full, skip
|
||||
}
|
||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.prefix = prefix
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.level = level
|
||||
}
|
||||
|
||||
func (w *LogMonitor) SetLogTimeFormat(timeFormat string) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.timeFormat = timeFormat
|
||||
}
|
||||
|
||||
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||
prefix := ""
|
||||
if w.prefix != "" {
|
||||
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||
}
|
||||
timestamp := ""
|
||||
if w.timeFormat != "" {
|
||||
timestamp = fmt.Sprintf("%s ", time.Now().Format(w.timeFormat))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s[%s] %s\n", timestamp, prefix, level, msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||
if level < w.level {
|
||||
return
|
||||
}
|
||||
w.Write(w.formatMessage(level.String(), msg))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debug(msg string) {
|
||||
w.log(LevelDebug, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Info(msg string) {
|
||||
w.log(LevelInfo, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warn(msg string) {
|
||||
w.log(LevelWarn, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Error(msg string) {
|
||||
w.log(LevelError, msg)
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case LevelDebug:
|
||||
return "DEBUG"
|
||||
case LevelInfo:
|
||||
return "INFO"
|
||||
case LevelWarn:
|
||||
return "WARN"
|
||||
case LevelError:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
+243
-22
@@ -3,45 +3,38 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogMonitor(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Test subscription
|
||||
client1 := logMonitor.Subscribe()
|
||||
client2 := logMonitor.Subscribe()
|
||||
|
||||
defer logMonitor.Unsubscribe(client1)
|
||||
defer logMonitor.Unsubscribe(client2)
|
||||
// A WaitGroup is used to wait for all the expected writes to complete
|
||||
var wg sync.WaitGroup
|
||||
|
||||
client1Messages := make([]byte, 0)
|
||||
client2Messages := make([]byte, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client1Messages = append(client1Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case data := <-client1:
|
||||
client1Messages = append(client1Messages, data...)
|
||||
case data := <-client2:
|
||||
client2Messages = append(client2Messages, data...)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer logMonitor.OnLogData(func(data []byte) {
|
||||
client2Messages = append(client2Messages, data...)
|
||||
wg.Done()
|
||||
})()
|
||||
|
||||
wg.Add(6) // 2 x 3 writes
|
||||
|
||||
logMonitor.Write([]byte("1"))
|
||||
logMonitor.Write([]byte("2"))
|
||||
logMonitor.Write([]byte("3"))
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
// wait for all writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check the buffer
|
||||
@@ -93,3 +86,231 @@ func TestWrite_ImmutableBuffer(t *testing.T) {
|
||||
t.Errorf("Expected history to be %q, got %q", expected, history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_LogTimeFormat(t *testing.T) {
|
||||
// Create a new LogMonitor instance
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Enable timestamps
|
||||
lm.timeFormat = time.RFC3339
|
||||
|
||||
// Write the message to the LogMonitor
|
||||
lm.Info("Hello, World!")
|
||||
|
||||
// Get the history from the LogMonitor
|
||||
history := lm.GetHistory()
|
||||
|
||||
timestamp := ""
|
||||
fields := strings.Fields(string(history))
|
||||
if len(fields) > 0 {
|
||||
timestamp = fields[0]
|
||||
} else {
|
||||
t.Fatalf("Cannot extract string from history")
|
||||
}
|
||||
|
||||
_, err := time.Parse(time.RFC3339, timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot find timestamp: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_WrapAround(t *testing.T) {
|
||||
// Create a small buffer to test wrap-around
|
||||
cb := newCircularBuffer(10)
|
||||
|
||||
// Write "hello" (5 bytes)
|
||||
cb.Write([]byte("hello"))
|
||||
if got := string(cb.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Write "world" (5 bytes) - buffer now full
|
||||
cb.Write([]byte("world"))
|
||||
if got := string(cb.GetHistory()); got != "helloworld" {
|
||||
t.Errorf("Expected 'helloworld', got %q", got)
|
||||
}
|
||||
|
||||
// Write "12345" (5 bytes) - should overwrite "hello"
|
||||
cb.Write([]byte("12345"))
|
||||
if got := string(cb.GetHistory()); got != "world12345" {
|
||||
t.Errorf("Expected 'world12345', got %q", got)
|
||||
}
|
||||
|
||||
// Write data larger than buffer capacity
|
||||
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
|
||||
if got := string(cb.GetHistory()); got != "ghijklmnop" {
|
||||
t.Errorf("Expected 'ghijklmnop', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
|
||||
// Test empty buffer
|
||||
cb := newCircularBuffer(10)
|
||||
if got := cb.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
|
||||
// Test exact capacity
|
||||
cb.Write([]byte("1234567890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
|
||||
// Test write exactly at capacity boundary
|
||||
cb = newCircularBuffer(10)
|
||||
cb.Write([]byte("12345"))
|
||||
cb.Write([]byte("67890"))
|
||||
if got := string(cb.GetHistory()); got != "1234567890" {
|
||||
t.Errorf("Expected '1234567890', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_LazyInit(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Buffer should be nil before any writes
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil before first write")
|
||||
}
|
||||
|
||||
// GetHistory should return nil when buffer is nil
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history before first write, got %q", got)
|
||||
}
|
||||
|
||||
// Write should lazily initialize the buffer
|
||||
lm.Write([]byte("test"))
|
||||
|
||||
if lm.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized after write")
|
||||
}
|
||||
|
||||
if got := string(lm.GetHistory()); got != "test" {
|
||||
t.Errorf("Expected 'test', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_Clear(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write some data
|
||||
lm.Write([]byte("hello"))
|
||||
if got := string(lm.GetHistory()); got != "hello" {
|
||||
t.Errorf("Expected 'hello', got %q", got)
|
||||
}
|
||||
|
||||
// Clear should release the buffer
|
||||
lm.Clear()
|
||||
|
||||
if lm.buffer != nil {
|
||||
t.Error("Expected buffer to be nil after Clear")
|
||||
}
|
||||
|
||||
if got := lm.GetHistory(); got != nil {
|
||||
t.Errorf("Expected nil history after Clear, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogMonitor_ClearAndReuse(t *testing.T) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
// Write, clear, then write again
|
||||
lm.Write([]byte("first"))
|
||||
lm.Clear()
|
||||
lm.Write([]byte("second"))
|
||||
|
||||
if got := string(lm.GetHistory()); got != "second" {
|
||||
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogMonitorWrite(b *testing.B) {
|
||||
// Test data of varying sizes
|
||||
smallMsg := []byte("small message\n")
|
||||
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
|
||||
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
|
||||
|
||||
b.Run("SmallWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(smallMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MediumWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LargeWrite", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(largeMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithSubscribers", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Add some subscribers
|
||||
for i := 0; i < 5; i++ {
|
||||
lm.OnLogData(func(data []byte) {})
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetHistory", func(b *testing.B) {
|
||||
lm := NewLogMonitorWriter(io.Discard)
|
||||
// Pre-populate with data
|
||||
for i := 0; i < 1000; i++ {
|
||||
lm.Write(mediumMsg)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lm.GetHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
Benchmark Results - MBP M1 Pro
|
||||
|
||||
Before (ring.Ring):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 43 ns | 40 B | 2 |
|
||||
| MediumWrite (241B) | 76 ns | 264 B | 2 |
|
||||
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
|
||||
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
|
||||
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
|
||||
|
||||
After (circularBuffer 10KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 67 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
|
||||
|
||||
After (circularBuffer 100KB):
|
||||
| Benchmark | ns/op | bytes/op | allocs/op |
|
||||
|---------------------------------|------------|-----------|-----------|
|
||||
| SmallWrite (14B) | 26 ns | 16 B | 1 |
|
||||
| MediumWrite (241B) | 66 ns | 240 B | 1 |
|
||||
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
|
||||
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
|
||||
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
|
||||
|
||||
Summary:
|
||||
- GetHistory: 139x faster (10KB), 18x faster (100KB)
|
||||
- Allocations: reduced from 2 to 1 across all operations
|
||||
- Small/medium writes: ~1.1-1.6x faster
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,515 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
}
|
||||
|
||||
type ReqRespCapture struct {
|
||||
ID int `json:"id"`
|
||||
ReqPath string `json:"req_path"`
|
||||
ReqHeaders map[string]string `json:"req_headers"`
|
||||
ReqBody []byte `json:"req_body"`
|
||||
RespHeaders map[string]string `json:"resp_headers"`
|
||||
RespBody []byte `json:"resp_body"`
|
||||
}
|
||||
|
||||
// Size returns the approximate memory usage of this capture in bytes
|
||||
func (c *ReqRespCapture) Size() int {
|
||||
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
||||
for k, v := range c.ReqHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
for k, v := range c.RespHeaders {
|
||||
size += len(k) + len(v)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// metricsMonitor parses llama-server output for token statistics
|
||||
type metricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
logger *LogMonitor
|
||||
|
||||
// capture fields
|
||||
enableCaptures bool
|
||||
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
||||
captureOrder []int // track insertion order for FIFO eviction
|
||||
captureSize int // current total size in bytes
|
||||
maxCaptureSize int // max bytes for captures
|
||||
}
|
||||
|
||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||
// capture buffer size in megabytes; 0 disables captures.
|
||||
func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int) *metricsMonitor {
|
||||
return &metricsMonitor{
|
||||
logger: logger,
|
||||
maxMetrics: maxMetrics,
|
||||
enableCaptures: captureBufferMB > 0,
|
||||
captures: make(map[int]ReqRespCapture),
|
||||
captureOrder: make([]int, 0),
|
||||
captureSize: 0,
|
||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||
}
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event.
|
||||
// Returns the assigned metric ID.
|
||||
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics = append(mp.metrics, metric)
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
return metric.ID
|
||||
}
|
||||
|
||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||
if !mp.enableCaptures {
|
||||
return
|
||||
}
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
captureSize := capture.Size()
|
||||
if captureSize > mp.maxCaptureSize {
|
||||
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest (FIFO) until room available
|
||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||
oldestID := mp.captureOrder[0]
|
||||
mp.captureOrder = mp.captureOrder[1:]
|
||||
if evicted, exists := mp.captures[oldestID]; exists {
|
||||
mp.captureSize -= evicted.Size()
|
||||
delete(mp.captures, oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
mp.captures[capture.ID] = capture
|
||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||
mp.captureSize += captureSize
|
||||
}
|
||||
|
||||
// getCaptureByID returns a capture by its ID, or nil if not found.
|
||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
if capture, exists := mp.captures[id]; exists {
|
||||
return &capture
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetrics returns a copy of the current metrics
|
||||
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
|
||||
// getMetricsJSON returns metrics as JSON
|
||||
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
|
||||
// wrapHandler wraps the proxy handler to extract token metrics
|
||||
// if wrapHandler returns an error it is safe to assume that no
|
||||
// data was sent to the client
|
||||
func (mp *metricsMonitor) wrapHandler(
|
||||
modelID string,
|
||||
writer gin.ResponseWriter,
|
||||
request *http.Request,
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
// Capture request body and headers if captures enabled
|
||||
var reqBody []byte
|
||||
var reqHeaders map[string]string
|
||||
if mp.enableCaptures {
|
||||
if request.Body != nil {
|
||||
var err error
|
||||
reqBody, err = io.ReadAll(request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body for capture: %w", err)
|
||||
}
|
||||
request.Body.Close()
|
||||
request.Body = io.NopCloser(bytes.NewBuffer(reqBody))
|
||||
}
|
||||
reqHeaders = make(map[string]string)
|
||||
for key, values := range request.Header {
|
||||
if len(values) > 0 {
|
||||
reqHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(reqHeaders)
|
||||
}
|
||||
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// after this point we have to assume that data was sent to the client
|
||||
// and we can only log errors but not send them to clients
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
// extract timings for infill - response is an array, timings are in the last element
|
||||
// see #463
|
||||
if strings.HasPrefix(request.URL.Path, "/infill") {
|
||||
if arr := parsed.Array(); len(arr) > 0 {
|
||||
timings = arr[len(arr)-1].Get("timings")
|
||||
}
|
||||
}
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Build capture if enabled and determine if it will be stored
|
||||
var capture *ReqRespCapture
|
||||
if mp.enableCaptures {
|
||||
respHeaders := make(map[string]string)
|
||||
for key, values := range recorder.Header() {
|
||||
if len(values) > 0 {
|
||||
respHeaders[key] = values[0]
|
||||
}
|
||||
}
|
||||
redactHeaders(respHeaders)
|
||||
delete(respHeaders, "Content-Encoding")
|
||||
capture = &ReqRespCapture{
|
||||
ReqPath: request.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqBody: reqBody,
|
||||
RespHeaders: respHeaders,
|
||||
RespBody: body,
|
||||
}
|
||||
// Only set HasCapture if the capture will actually be stored (not too large)
|
||||
if capture.Size() <= mp.maxCaptureSize {
|
||||
tm.HasCapture = true
|
||||
}
|
||||
}
|
||||
|
||||
metricID := mp.addMetrics(tm)
|
||||
|
||||
// Store capture if enabled
|
||||
if capture != nil {
|
||||
capture.ID = metricID
|
||||
mp.addCapture(*capture)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||
// Iterate **backwards** through the body looking for the data payload with
|
||||
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||
|
||||
// Start from the end of the body and scan backwards for newlines
|
||||
pos := len(body)
|
||||
for pos > 0 {
|
||||
// Find the previous newline (or start of body)
|
||||
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||
if lineStart == -1 {
|
||||
lineStart = 0
|
||||
} else {
|
||||
lineStart++ // Move past the newline
|
||||
}
|
||||
|
||||
line := bytes.TrimSpace(body[lineStart:pos])
|
||||
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE payload always follows "data:"
|
||||
prefix := []byte("data:")
|
||||
if !bytes.HasPrefix(line, prefix) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len(prefix):])
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
// [DONE] line itself contains nothing of interest.
|
||||
continue
|
||||
}
|
||||
|
||||
if gjson.ValidBytes(data) {
|
||||
parsed := gjson.ParseBytes(data)
|
||||
usage := parsed.Get("usage")
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
return parseMetrics(modelID, start, usage, timings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||
}
|
||||
|
||||
func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
|
||||
// default values
|
||||
cachedTokens := -1 // unknown or missing data
|
||||
outputTokens := 0
|
||||
inputTokens := 0
|
||||
|
||||
// timings data
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
durationMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
if usage.Exists() {
|
||||
if pt := usage.Get("prompt_tokens"); pt.Exists() {
|
||||
// v1/chat/completions
|
||||
inputTokens = int(pt.Int())
|
||||
} else if it := usage.Get("input_tokens"); it.Exists() {
|
||||
// v1/messages
|
||||
inputTokens = int(it.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("completion_tokens"); ct.Exists() {
|
||||
// v1/chat/completions
|
||||
outputTokens = int(ct.Int())
|
||||
} else if ot := usage.Get("output_tokens"); ot.Exists() {
|
||||
outputTokens = int(ot.Int())
|
||||
}
|
||||
|
||||
if ct := usage.Get("cache_read_input_tokens"); ct.Exists() {
|
||||
cachedTokens = int(ct.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||
if timings.Exists() {
|
||||
inputTokens = int(timings.Get("prompt_n").Int())
|
||||
outputTokens = int(timings.Get("predicted_n").Int())
|
||||
promptPerSecond = timings.Get("prompt_per_second").Float()
|
||||
tokensPerSecond = timings.Get("predicted_per_second").Float()
|
||||
durationMs = int(timings.Get("prompt_ms").Float() + timings.Get("predicted_ms").Float())
|
||||
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = int(cachedValue.Int())
|
||||
}
|
||||
}
|
||||
|
||||
return TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
CachedTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
PromptPerSecond: promptPerSecond,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: durationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
tee io.Writer
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||
bodyBuffer := &bytes.Buffer{}
|
||||
return &responseBodyCopier{
|
||||
ResponseWriter: w,
|
||||
body: bodyBuffer,
|
||||
tee: io.MultiWriter(w, bodyBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||
if w.start.IsZero() {
|
||||
w.start = time.Now()
|
||||
}
|
||||
|
||||
// Single write operation that writes to both the response and buffer
|
||||
return w.tee.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// sensitiveHeaders lists headers that should be redacted in captures
|
||||
var sensitiveHeaders = map[string]bool{
|
||||
"authorization": true,
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"set-cookie": true,
|
||||
"x-api-key": true,
|
||||
}
|
||||
|
||||
// redactHeaders replaces sensitive header values in-place with "[REDACTED]"
|
||||
func redactHeaders(headers map[string]string) {
|
||||
for key := range headers {
|
||||
if sensitiveHeaders[strings.ToLower(key)] {
|
||||
headers[key] = "[REDACTED]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,141 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
// Create a shared transport with reasonable timeouts for peer connections
|
||||
// these can be tuned with feedback later
|
||||
peerTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second, // Connection timeout
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
+761
-218
File diff suppressed because it is too large
Load Diff
+419
-14
@@ -2,24 +2,38 @@ package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
debugLogger = NewLogMonitorWriter(os.Stdout)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// flip to help with debugging tests
|
||||
if false {
|
||||
debugLogger.SetLogLevel(LevelDebug)
|
||||
} else {
|
||||
debugLogger.SetLogLevel(LevelError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// Create a process
|
||||
process := NewProcess("test-process", 5, config, logMonitor)
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -48,26 +62,55 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests
|
||||
// are all handled successfully, even though they all may ask for the process to .start()
|
||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID)
|
||||
assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// test that the automatic start returns the expected error type
|
||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||
// Create a process configuration
|
||||
config := ModelConfig{
|
||||
config := config.ModelConfig{
|
||||
Cmd: "nonexistent-command",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||
defer process.Stop()
|
||||
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||
}
|
||||
|
||||
// test that the process unloads after the TTL
|
||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long auto unload TTL test")
|
||||
@@ -79,7 +122,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -111,7 +154,36 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
assert.Equal(t, StateStopped, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_LowTTLValue(t *testing.T) {
|
||||
if true { // change this code to run this ...
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Logf("Waiting before sending request %d", i)
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
expected := fmt.Sprintf("echo=test_%d", i)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// issue #19
|
||||
// This test makes sure using Process.Stop() does not affect pending HTTP
|
||||
// requests. All HTTP requests in this test should complete successfully.
|
||||
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
@@ -119,7 +191,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
|
||||
expectedMessage := "12345"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
||||
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
results := map[string]string{
|
||||
@@ -135,8 +207,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
// send a request that should take 5 * 200ms (1 second) to complete
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
|
||||
// send a request where simple-responder is will wait 300ms before responding
|
||||
// this will simulate an in-progress request.
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
process.ProxyRequest(w, req)
|
||||
@@ -152,9 +225,9 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
}(key)
|
||||
}
|
||||
|
||||
// stop the requests in the middle
|
||||
// Stop the process while requests are still being processed
|
||||
go func() {
|
||||
<-time.After(500 * time.Millisecond)
|
||||
<-time.After(150 * time.Millisecond)
|
||||
process.Stop()
|
||||
}()
|
||||
|
||||
@@ -164,3 +237,335 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_SwapState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentState ProcessState
|
||||
expectedState ProcessState
|
||||
newState ProcessState
|
||||
expectedError error
|
||||
expectedResult ProcessState
|
||||
}{
|
||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||
p.state = test.currentState
|
||||
|
||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||
if err != nil && test.expectedError == nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && test.expectedError != nil {
|
||||
t.Errorf("Expected error: %v, but got none", test.expectedError)
|
||||
} else if err != nil && test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("Expected error: %v, got: %v", test.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
if resultState != test.expectedResult {
|
||||
t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long shutdown test")
|
||||
}
|
||||
|
||||
expectedMessage := "testing91931"
|
||||
|
||||
// make a config where the healthcheck will always fail because port is wrong
|
||||
config := getTestSimpleResponderConfigPort(expectedMessage, 9999)
|
||||
config.Proxy = "http://localhost:9998/test"
|
||||
|
||||
healthCheckTTLSeconds := 30
|
||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||
|
||||
// make it a lot faster
|
||||
process.healthCheckLoopInterval = time.Second
|
||||
|
||||
// start a goroutine to simulate a shutdown
|
||||
var wg sync.WaitGroup
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-time.After(time.Millisecond * 500)
|
||||
process.Shutdown()
|
||||
}()
|
||||
wg.Add(1)
|
||||
|
||||
// start the process, this is a blocking call
|
||||
err := process.start()
|
||||
|
||||
wg.Wait()
|
||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||
}
|
||||
|
||||
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping Exit Interrupts Health Check test")
|
||||
}
|
||||
|
||||
// should run and exit but interrupt the long checkHealthTimeout
|
||||
checkHealthTimeout := 5
|
||||
config := config.ModelConfig{
|
||||
Cmd: "sleep 1",
|
||||
Proxy: "http://127.0.0.1:9913",
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||
process.healthCheckLoopInterval = time.Second // make it faster
|
||||
err := process.start()
|
||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping long concurrency limit test")
|
||||
}
|
||||
|
||||
expectedMessage := "concurrency_limit_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// only allow 1 concurrent request at a time
|
||||
config.ConcurrencyLimit = 1
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||
defer process.Stop()
|
||||
|
||||
// launch a goroutine first to take up the semaphore
|
||||
go func() {
|
||||
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req1)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}()
|
||||
|
||||
// let the goroutine start
|
||||
<-time.After(time.Millisecond * 25)
|
||||
|
||||
denied := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, denied)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestProcess_StopImmediately(t *testing.T) {
|
||||
expectedMessage := "test_stop_immediate"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
}()
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
expectedMessage := "test_sigkill"
|
||||
binaryPath := getSimpleResponderPath()
|
||||
port := getTestPort()
|
||||
|
||||
conf := config.ModelConfig{
|
||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||
// to force the process to exit
|
||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
CheckEndpoint: "/health",
|
||||
}
|
||||
|
||||
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// reduce to make testing go faster
|
||||
process.gracefulStopTimeout = time.Second
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
|
||||
waitChan := make(chan struct{})
|
||||
go func() {
|
||||
// slow, but will get killed by StopImmediate
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||
w := httptest.NewRecorder()
|
||||
process.ProxyRequest(w, req)
|
||||
|
||||
// StatusOK because that was already sent before the kill
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||
// then the unexpected EOF is sent after the kill
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||
} else {
|
||||
// Upstream may be killed mid-response.
|
||||
// Assert an incomplete or partial response.
|
||||
assert.NotEqual(t, "12345", w.Body.String())
|
||||
}
|
||||
|
||||
close(waitChan)
|
||||
}()
|
||||
|
||||
<-time.After(time.Millisecond)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
conf.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||
expectedMessage := "test_env_not_emptied"
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
// ensure that the the default config does not blank out the inherited environment
|
||||
configWEnv := conf
|
||||
|
||||
// ensure the additiona variables are appended to the process' environment
|
||||
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||
|
||||
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||
|
||||
process1.start()
|
||||
defer process1.Stop()
|
||||
process2.start()
|
||||
defer process2.Stop()
|
||||
|
||||
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||
|
||||
}
|
||||
|
||||
// TestProcess_ReverseProxyPanicIsHandled tests that panics from
|
||||
// httputil.ReverseProxy in Process.ProxyRequest(w, r) do not bubble up and are
|
||||
// handled appropriately.
|
||||
//
|
||||
// httputil.ReverseProxy will panic with http.ErrAbortHandler when it has sent headers
|
||||
// can't copy the body. This can be caused by a client disconnecting before the full
|
||||
// response is sent from some reason.
|
||||
//
|
||||
// bug: https://github.com/mostlygeek/llama-swap/issues/362
|
||||
// see: https://github.com/golang/go/issues/23643 (where panic was added to httputil.ReverseProxy)
|
||||
func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) {
|
||||
// Add defer/recover to catch any panics that aren't handled by ProxyRequest
|
||||
// If this recover() is hit, it means ProxyRequest didn't handle the panic properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("ProxyRequest should handle panics from reverseProxy.ServeHTTP, but panic was not caught: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedMessage := "panic_test"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
|
||||
process := NewProcess("panic-test", 5, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// Start the process
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
|
||||
// Create a custom ResponseWriter that simulates a client disconnect
|
||||
// by panicking when Write is called after headers are sent
|
||||
panicWriter := &panicOnWriteResponseWriter{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
shouldPanic: true,
|
||||
}
|
||||
|
||||
// Make a request that will trigger the panic
|
||||
req := httptest.NewRequest("GET", "/slow-respond?echo=test&delay=100ms", nil)
|
||||
|
||||
// This should panic inside reverseProxy.ServeHTTP when the panicWriter.Write() is called.
|
||||
// ProxyRequest should catch and handle this panic gracefully.
|
||||
process.ProxyRequest(panicWriter, req)
|
||||
|
||||
// If we get here, the panic was properly recovered in ProxyRequest
|
||||
// The process should still be in a ready state
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
|
||||
// panicOnWriteResponseWriter is a ResponseWriter that panics on Write
|
||||
// to simulate a client disconnect after headers are sent
|
||||
// used by: TestProcess_ReverseProxyPanicIsHandled
|
||||
type panicOnWriteResponseWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
shouldPanic bool
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) WriteHeader(statusCode int) {
|
||||
w.headerWritten = true
|
||||
w.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.shouldPanic && w.headerWritten {
|
||||
// Simulate the panic that httputil.ReverseProxy throws
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return w.ResponseRecorder.Write(b)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
// No-op on Unix systems
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttributes sets platform-specific process attributes
|
||||
func setProcAttributes(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: 0x08000000, // CREATE_NO_WINDOW
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type ProcessGroup struct {
|
||||
sync.Mutex
|
||||
|
||||
config config.Config
|
||||
id string
|
||||
swap bool
|
||||
exclusive bool
|
||||
persistent bool
|
||||
|
||||
proxyLogger *LogMonitor
|
||||
upstreamLogger *LogMonitor
|
||||
|
||||
// map of current processes
|
||||
processes map[string]*Process
|
||||
lastUsedProcess string
|
||||
}
|
||||
|
||||
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||
groupConfig, ok := config.Groups[id]
|
||||
if !ok {
|
||||
panic("Unable to find configuration for group id: " + id)
|
||||
}
|
||||
|
||||
pg := &ProcessGroup{
|
||||
id: id,
|
||||
config: config,
|
||||
swap: groupConfig.Swap,
|
||||
exclusive: groupConfig.Exclusive,
|
||||
persistent: groupConfig.Persistent,
|
||||
proxyLogger: proxyLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
processes: make(map[string]*Process),
|
||||
}
|
||||
|
||||
// Create a Process for each member in the group
|
||||
for _, modelID := range groupConfig.Members {
|
||||
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||
processLogger := NewLogMonitorWriter(upstreamLogger)
|
||||
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
|
||||
pg.processes[modelID] = process
|
||||
}
|
||||
|
||||
return pg
|
||||
}
|
||||
|
||||
// ProxyRequest proxies a request to the specified model
|
||||
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||
if !pg.HasMember(modelID) {
|
||||
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||
}
|
||||
|
||||
if pg.swap {
|
||||
pg.Lock()
|
||||
if pg.lastUsedProcess != modelID {
|
||||
|
||||
// is there something already running?
|
||||
if pg.lastUsedProcess != "" {
|
||||
pg.processes[pg.lastUsedProcess].Stop()
|
||||
}
|
||||
|
||||
// wait for the request to the new model to be fully handled
|
||||
// and prevent race conditions see issue #277
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
pg.lastUsedProcess = modelID
|
||||
|
||||
// short circuit and exit
|
||||
pg.Unlock()
|
||||
return nil
|
||||
}
|
||||
pg.Unlock()
|
||||
}
|
||||
|
||||
pg.processes[modelID].ProxyRequest(writer, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
|
||||
if pg.HasMember(modelName) {
|
||||
return pg.processes[modelName], true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
if len(pg.processes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// stop Processes in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) Shutdown() {
|
||||
var wg sync.WaitGroup
|
||||
for _, process := range pg.processes {
|
||||
wg.Add(1)
|
||||
go func(process *Process) {
|
||||
defer wg.Done()
|
||||
process.Shutdown()
|
||||
}(process)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
"model4": getTestSimpleResponderConfig("model4"),
|
||||
"model5": getTestSimpleResponderConfig("model5"),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Exclusive: true,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
"G2": {
|
||||
Swap: false,
|
||||
Exclusive: true,
|
||||
Members: []string{"model3", "model4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model5"))
|
||||
}
|
||||
|
||||
func TestProcessGroup_HasMember(t *testing.T) {
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
assert.True(t, pg.HasMember("model1"))
|
||||
assert.True(t, pg.HasMember("model2"))
|
||||
assert.False(t, pg.HasMember("model3"))
|
||||
}
|
||||
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
// use the same listening so if a model is already running, it will fail
|
||||
// this is a way to test that swap isolation is working
|
||||
// properly when there are parallel requests made at the
|
||||
// same time.
|
||||
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||
},
|
||||
Groups: map[string]config.GroupConfig{
|
||||
"G1": {
|
||||
Swap: true,
|
||||
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tests))
|
||||
for _, modelName := range tests {
|
||||
go func(modelName string) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
}(modelName)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
tests := []string{"model3", "model4"}
|
||||
|
||||
for _, modelName := range tests {
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
reqBody := `{"x", "y"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), modelName)
|
||||
})
|
||||
}
|
||||
|
||||
// make sure all the processes are running
|
||||
for _, process := range pg.processes {
|
||||
assert.Equal(t, StateReady, process.CurrentState())
|
||||
}
|
||||
}
|
||||
+956
-200
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,294 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
apiGroup.GET("/version", pm.apiGetVersion)
|
||||
apiGroup.GET("/captures/:id", pm.apiGetCapture)
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) getModelStatus() []Model {
|
||||
// Extract keys and sort them
|
||||
models := []Model{}
|
||||
|
||||
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||
for modelID := range pm.config.Models {
|
||||
modelIDs = append(modelIDs, modelID)
|
||||
}
|
||||
sort.Strings(modelIDs)
|
||||
|
||||
// Iterate over sorted keys
|
||||
for _, modelID := range modelIDs {
|
||||
// Get process state
|
||||
processGroup := pm.findGroupByModelName(modelID)
|
||||
state := "unknown"
|
||||
if processGroup != nil {
|
||||
process := processGroup.processes[modelID]
|
||||
if process != nil {
|
||||
var stateStr string
|
||||
switch process.CurrentState() {
|
||||
case StateReady:
|
||||
stateStr = "ready"
|
||||
case StateStarting:
|
||||
stateStr = "starting"
|
||||
case StateStopping:
|
||||
stateStr = "stopping"
|
||||
case StateShutdown:
|
||||
stateStr = "shutdown"
|
||||
case StateStopped:
|
||||
stateStr = "stopped"
|
||||
default:
|
||||
stateStr = "unknown"
|
||||
}
|
||||
state = stateStr
|
||||
}
|
||||
}
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
Name: pm.config.Models[modelID].Name,
|
||||
Description: pm.config.Models[modelID].Description,
|
||||
State: state,
|
||||
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||
Aliases: pm.config.Models[modelID].Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
type messageType string
|
||||
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
msgTypeInFlight messageType = "inflight"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
Type messageType `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// sends a stream of different message types that happen on the server
|
||||
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering SSE
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
sendBuffer := make(chan messageEnvelope, 25)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
sendModels := func() {
|
||||
data, err := json.Marshal(pm.getModelStatus())
|
||||
if err == nil {
|
||||
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||
select {
|
||||
case sendBuffer <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sendLogData := func(source string, data []byte) {
|
||||
data, err := json.Marshal(gin.H{
|
||||
"source": source,
|
||||
"data": string(data),
|
||||
})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics []TokenMetrics) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sendInFlight := func(total int) {
|
||||
jsonData, err := json.Marshal(gin.H{"total": total})
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeInFlight, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
defer event.On(func(e ProcessStateChangeEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
defer event.On(func(e ConfigFileChangedEvent) {
|
||||
sendModels()
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Log data
|
||||
*/
|
||||
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("proxy", data)
|
||||
})()
|
||||
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics([]TokenMetrics{e.Metrics})
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send in-flight request stats related to token stats "Waiting: N" count.
|
||||
*/
|
||||
defer event.On(func(e InFlightRequestsEvent) {
|
||||
sendInFlight(e.Total)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||
sendInFlight(pm.inFlightCounter.Current())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case msg := <-sendBuffer:
|
||||
c.SSEvent("message", msg)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, map[string]string{
|
||||
"version": pm.version,
|
||||
"commit": pm.commit,
|
||||
"build_date": pm.buildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid capture ID"})
|
||||
return
|
||||
}
|
||||
|
||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
||||
if capture == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, capture)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -9,26 +10,12 @@ import (
|
||||
)
|
||||
|
||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "text/html") {
|
||||
// Set the Content-Type header to text/html
|
||||
c.Header("Content-Type", "text/html")
|
||||
|
||||
// Write the embedded HTML content to the response
|
||||
logsHTML, err := getHTMLFile("logs.html")
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_, err = c.Writer.Write(logsHTML)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, "/ui/")
|
||||
} else {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := pm.muxLogger.GetHistory()
|
||||
_, err := c.Writer.Write(history)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, err)
|
||||
@@ -41,11 +28,16 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/plain")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
// prevent nginx from buffering streamed logs
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
|
||||
logger, err := pm.getLogger(logMonitorId)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||
@@ -56,58 +48,60 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
||||
// Send history first if not skipped
|
||||
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
history := logger.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.Writer.Write(history)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
sendChan := make(chan []byte, 10)
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer logger.OnLogData(func(data []byte) {
|
||||
select {
|
||||
case sendChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
})()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
_, err := c.Writer.Write(msg)
|
||||
if err != nil {
|
||||
// just break the loop if we can't write for some reason
|
||||
return
|
||||
}
|
||||
case <-c.Request.Context().Done():
|
||||
cancel()
|
||||
return
|
||||
case <-pm.shutdownCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case data := <-sendChan:
|
||||
c.Writer.Write(data)
|
||||
flusher.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
ch := pm.logMonitor.Subscribe()
|
||||
defer pm.logMonitor.Unsubscribe(ch)
|
||||
|
||||
notify := c.Request.Context().Done()
|
||||
|
||||
// Send history first if not skipped
|
||||
_, skipHistory := c.GetQuery("no-history")
|
||||
if !skipHistory {
|
||||
history := pm.logMonitor.GetHistory()
|
||||
if len(history) != 0 {
|
||||
c.SSEvent("message", string(history))
|
||||
c.Writer.Flush()
|
||||
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||
switch logMonitorId {
|
||||
case "":
|
||||
// maintain the default
|
||||
return pm.muxLogger, nil
|
||||
case "proxy":
|
||||
return pm.proxyLogger, nil
|
||||
case "upstream":
|
||||
return pm.upstreamLogger, nil
|
||||
default:
|
||||
// search for a models specific logger using findModelInPath
|
||||
// to handle model names with slashes (e.g., "author/model")
|
||||
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
|
||||
for _, group := range pm.processGroups {
|
||||
if process, found := group.GetMember(name); found {
|
||||
return process.Logger(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream new logs
|
||||
for {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
c.SSEvent("message", string(msg))
|
||||
c.Writer.Flush()
|
||||
case <-notify:
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
|
||||
}
|
||||
}
|
||||
|
||||
+1513
-64
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,43 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isTokenChar(r rune) bool {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||
parts := strings.Split(headerValues, ",")
|
||||
valid := make([]string, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
validPart := true
|
||||
for _, c := range v {
|
||||
if !isTokenChar(c) {
|
||||
validPart = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if validPart {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(valid, ", ")
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single valid value",
|
||||
input: "content-type",
|
||||
expected: "content-type",
|
||||
},
|
||||
{
|
||||
name: "multiple valid values",
|
||||
input: "content-type, authorization, x-requested-with",
|
||||
expected: "content-type, authorization, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "values with extra spaces",
|
||||
input: " content-type , authorization ",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with tabs",
|
||||
input: "content-type,\tauthorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "values with invalid characters",
|
||||
input: "content-type, auth\n, x-requested-with\r",
|
||||
expected: "content-type, auth, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "empty values in list",
|
||||
input: "content-type,,authorization",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing commas",
|
||||
input: ",content-type,authorization,",
|
||||
expected: "content-type, authorization",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid values",
|
||||
input: "content-type, \x00invalid, x-requested-with",
|
||||
expected: "content-type, x-requested-with",
|
||||
},
|
||||
{
|
||||
name: "mixed case values",
|
||||
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||
tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// selectEncoding chooses the best encoding based on Accept-Encoding header
|
||||
// Returns the encoding ("br", "gzip", or "") and the corresponding file extension
|
||||
func selectEncoding(acceptEncoding string) (encoding, ext string) {
|
||||
if acceptEncoding == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "br" {
|
||||
return "br", ".br"
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
enc := strings.TrimSpace(strings.SplitN(part, ";", 2)[0])
|
||||
if enc == "gzip" {
|
||||
return "gzip", ".gz"
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// ServeCompressedFile serves a file with compression support.
|
||||
// It checks for pre-compressed versions and serves them with proper headers.
|
||||
func ServeCompressedFile(fs http.FileSystem, w http.ResponseWriter, r *http.Request, name string) {
|
||||
encoding, ext := selectEncoding(r.Header.Get("Accept-Encoding"))
|
||||
|
||||
// Try to serve compressed version if client supports it
|
||||
if encoding != "" {
|
||||
if cf, err := fs.Open(name + ext); err == nil {
|
||||
defer cf.Close()
|
||||
|
||||
// Verify it's a regular file (not a directory)
|
||||
if stat, err := cf.Stat(); err == nil && !stat.IsDir() {
|
||||
// Set the content encoding header
|
||||
w.Header().Set("Content-Encoding", encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
// Get original file info for content type detection
|
||||
origFile, err := fs.Open(name)
|
||||
if err == nil {
|
||||
origFile.Close()
|
||||
}
|
||||
|
||||
// Serve the compressed file
|
||||
http.ServeContent(w, r, name, stat.ModTime(), cf)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to serving the uncompressed file
|
||||
file, err := fs.Open(name)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, name, stat.ModTime(), file)
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServeCompressedFile_Brotli(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with brotli")
|
||||
brContent := []byte("fake-brotli-compressed-data")
|
||||
|
||||
// Create a test filesystem
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: brContent, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("fake-gzip-data"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that brotli is used (preferred over gzip)
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if vary := resp.Header.Get("Vary"); vary != "Accept-Encoding" {
|
||||
t.Errorf("Expected Vary 'Accept-Encoding', got '%s'", vary)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, brContent) {
|
||||
t.Errorf("Expected brotli content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_Gzip(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content that should be compressed with gzip")
|
||||
gzContent := []byte("fake-gzip-compressed-data")
|
||||
|
||||
// Create a test filesystem without brotli
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.gz": {Data: gzContent, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, gzContent) {
|
||||
t.Errorf("Expected gzip content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_UncompressedFallback(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is uncompressed test content")
|
||||
|
||||
// Create a test filesystem without compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
req.Header.Set("Accept-Encoding", "br, gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should not have Content-Encoding header since we're serving uncompressed
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NoAcceptEncoding(t *testing.T) {
|
||||
// Create test content
|
||||
content := []byte("This is test content")
|
||||
|
||||
// Create a test filesystem with compressed versions
|
||||
mapFS := fstest.MapFS{
|
||||
"test.js": {Data: content, ModTime: time.Now()},
|
||||
"test.js.br": {Data: []byte("brotli"), ModTime: time.Now()},
|
||||
"test.js.gz": {Data: []byte("gzip"), ModTime: time.Now()},
|
||||
}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test.js", nil)
|
||||
// No Accept-Encoding header
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "test.js")
|
||||
|
||||
resp := w.Result()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Should serve uncompressed content
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
t.Errorf("Expected no Content-Encoding, got '%s'", encoding)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, content) {
|
||||
t.Errorf("Expected original content, got %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeCompressedFile_NotFound(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
fs := http.FS(mapFS)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/nonexistent.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, "nonexistent.js")
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
acceptEncoding string
|
||||
wantEncoding string
|
||||
wantExt string
|
||||
}{
|
||||
{"br, gzip", "br", ".br"},
|
||||
{"gzip, deflate", "gzip", ".gz"},
|
||||
{"gzip", "gzip", ".gz"},
|
||||
{"br", "br", ".br"},
|
||||
{"", "", ""},
|
||||
{"deflate", "", ""},
|
||||
{"br;q=1.0, gzip;q=0.5", "br", ".br"},
|
||||
{"gzip;q=1.0, br;q=0.5", "br", ".br"},
|
||||
{"browser", "", ""},
|
||||
{"compress, deflate", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEncoding, gotExt := selectEncoding(tt.acceptEncoding)
|
||||
if gotEncoding != tt.wantEncoding || gotExt != tt.wantExt {
|
||||
t.Errorf("selectEncoding(%q) = (%q, %q), want (%q, %q)",
|
||||
tt.acceptEncoding, gotEncoding, gotExt, tt.wantEncoding, tt.wantExt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual pre-compressed files from ui_dist
|
||||
func TestServeCompressedFile_RealFiles(t *testing.T) {
|
||||
// Check if ui_dist exists
|
||||
if _, err := os.Stat("./ui_dist"); os.IsNotExist(err) {
|
||||
t.Skip("ui_dist not found, skipping real file test")
|
||||
}
|
||||
|
||||
// Find a .js or .css file that has compressed versions
|
||||
entries, err := os.ReadDir("./ui_dist/assets")
|
||||
if err != nil {
|
||||
t.Skipf("Could not read ui_dist/assets: %v", err)
|
||||
}
|
||||
|
||||
var testFile string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, ".js") && !strings.HasSuffix(name, ".js.gz") && !strings.HasSuffix(name, ".js.br") {
|
||||
// Check if compressed versions exist
|
||||
base := strings.TrimSuffix(name, ".js")
|
||||
if _, err := os.Stat(filepath.Join("./ui_dist/assets", base+".js.gz")); err == nil {
|
||||
testFile = "assets/" + name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testFile == "" {
|
||||
t.Skip("No suitable test file found with compressed versions")
|
||||
}
|
||||
|
||||
fs := http.FS(os.DirFS("./ui_dist"))
|
||||
|
||||
// Test brotli
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "br" {
|
||||
t.Errorf("Expected Content-Encoding 'br', got '%s'", encoding)
|
||||
}
|
||||
})
|
||||
|
||||
// Test gzip
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+testFile, nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ServeCompressedFile(fs, w, req, testFile)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "gzip" {
|
||||
t.Errorf("Expected Content-Encoding 'gzip', got '%s'", encoding)
|
||||
}
|
||||
|
||||
// Verify it's valid gzip
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid gzip content: %v", err)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Just read to verify it's valid
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decompress gzip: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//go:embed ui_dist
|
||||
var reactStaticFS embed.FS
|
||||
|
||||
// GetReactFS returns the embedded React filesystem
|
||||
func GetReactFS() (http.FileSystem, error) {
|
||||
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.FS(subFS), nil
|
||||
}
|
||||
|
||||
// GetReactIndexHTML returns the main index.html for the React app
|
||||
func GetReactIndexHTML() ([]byte, error) {
|
||||
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
#!/bin/sh
|
||||
# This script installs llama-swap on Linux.
|
||||
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||
|
||||
set -eu
|
||||
|
||||
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v "$1" >/dev/null; }
|
||||
require() {
|
||||
_MISSING=''
|
||||
for TOOL in "$@"; do
|
||||
if ! available "$TOOL"; then
|
||||
_MISSING="$_MISSING $TOOL"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "$_MISSING"
|
||||
}
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
NEEDS=$(require tee tar python3 mktemp)
|
||||
if [ -n "$NEEDS" ]; then
|
||||
status "ERROR: The following tools are required but missing:"
|
||||
for NEED in $NEEDS; do
|
||||
echo " - $NEED"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64|arm64) ARCH="arm64" ;;
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
IS_WSL2=false
|
||||
|
||||
KERN=$(uname -r)
|
||||
case "$KERN" in
|
||||
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||
*) ;;
|
||||
esac
|
||||
|
||||
download_binary() {
|
||||
ASSET_NAME="linux_$ARCH"
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||
PYTHON_SCRIPT=$(cat <<EOF
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
ASSET_NAME = "${ASSET_NAME}"
|
||||
|
||||
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||
data = json.load(resp)
|
||||
for asset in data.get("assets", []):
|
||||
if ASSET_NAME in asset.get("name", ""):
|
||||
url = asset["browser_download_url"]
|
||||
break
|
||||
else:
|
||||
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
print("Downloading:", url, file=sys.stderr)
|
||||
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||
urllib.request.urlretrieve(url, output_path)
|
||||
print(output_path)
|
||||
EOF
|
||||
)
|
||||
|
||||
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||
if [ ! -f "$TARFILE" ]; then
|
||||
error "Failed to download binary."
|
||||
fi
|
||||
|
||||
status "Extracting to /usr/local/bin"
|
||||
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||
}
|
||||
download_binary
|
||||
|
||||
configure_systemd() {
|
||||
if ! id llama-swap >/dev/null 2>&1; then
|
||||
status "Creating llama-swap user..."
|
||||
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||
fi
|
||||
if getent group render >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to render group..."
|
||||
$SUDO usermod -a -G render llama-swap
|
||||
fi
|
||||
if getent group video >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to video group..."
|
||||
$SUDO usermod -a -G video llama-swap
|
||||
fi
|
||||
if getent group docker >/dev/null 2>&1; then
|
||||
status "Adding llama-swap user to docker group..."
|
||||
$SUDO usermod -a -G docker llama-swap
|
||||
fi
|
||||
|
||||
status "Adding current user to llama-swap group..."
|
||||
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||
|
||||
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
status "Creating default config.yaml..."
|
||||
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||
# default 15s likely to fail for default models due to downloading models
|
||||
healthCheckTimeout: 60
|
||||
|
||||
models:
|
||||
"qwen2.5":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name qwen2.5
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop qwen2.5
|
||||
|
||||
"smollm2":
|
||||
cmd: |
|
||||
docker run
|
||||
--rm
|
||||
-p \${PORT}:8080
|
||||
--name smollm2
|
||||
ghcr.io/ggml-org/llama.cpp:server
|
||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||
cmdStop: docker stop smollm2
|
||||
EOF
|
||||
fi
|
||||
|
||||
status "Creating llama-swap systemd service..."
|
||||
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||
[Unit]
|
||||
Description=llama-swap
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=llama-swap
|
||||
Group=llama-swap
|
||||
|
||||
# set this to match your environment
|
||||
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||
|
||||
Restart=on-failure
|
||||
RestartSec=3
|
||||
StartLimitBurst=3
|
||||
StartLimitInterval=30
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||
case $SYSTEMCTL_RUNNING in
|
||||
running|degraded)
|
||||
status "Enabling and starting llama-swap service..."
|
||||
$SUDO systemctl daemon-reload
|
||||
$SUDO systemctl enable llama-swap
|
||||
|
||||
start_service() { $SUDO systemctl restart llama-swap; }
|
||||
trap start_service EXIT
|
||||
;;
|
||||
*)
|
||||
warning "systemd is not running"
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
install_success() {
|
||||
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||
status 'Install complete.'
|
||||
}
|
||||
|
||||
# WSL2 only supports GPUs via nvidia passthrough
|
||||
# so check for nvidia-smi to determine if GPU is available
|
||||
if [ "$IS_WSL2" = true ]; then
|
||||
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||
status "Nvidia GPU detected."
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
install_success
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/bin/sh
|
||||
# This script uninstalls llama-swap on Linux.
|
||||
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||
|
||||
status() { echo ">>> $*" >&2; }
|
||||
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||
|
||||
available() { command -v $1 >/dev/null; }
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
if ! available sudo; then
|
||||
error "This script requires superuser permissions. Please re-run as root."
|
||||
fi
|
||||
|
||||
SUDO="sudo"
|
||||
fi
|
||||
|
||||
configure_systemd() {
|
||||
status "Stopping llama-swap service..."
|
||||
$SUDO systemctl stop llama-swap
|
||||
|
||||
status "Disabling llama-swap service..."
|
||||
$SUDO systemctl disable llama-swap
|
||||
}
|
||||
if available systemctl; then
|
||||
configure_systemd
|
||||
fi
|
||||
|
||||
if available llama-swap; then
|
||||
status "Removing llama-swap binary..."
|
||||
$SUDO rm $(which llama-swap)
|
||||
fi
|
||||
|
||||
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||
while true; do
|
||||
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||
read answer
|
||||
case "$answer" in
|
||||
[Yy]* )
|
||||
$SUDO rm -r /usr/share/llama-swap
|
||||
break
|
||||
;;
|
||||
[Nn]* | "" )
|
||||
break
|
||||
;;
|
||||
* )
|
||||
echo "Invalid input. Please enter y or n."
|
||||
;;
|
||||
esac
|
||||
done
|
||||
fi
|
||||
|
||||
if id llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap user..."
|
||||
$SUDO userdel llama-swap
|
||||
fi
|
||||
|
||||
if getent group llama-swap >/dev/null 2>&1; then
|
||||
status "Removing llama-swap group..."
|
||||
$SUDO groupdel llama-swap
|
||||
fi
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user