Compare commits
141 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 | |||
| afc9aef058 | |||
| d7b390df74 | |||
| 5025c2f1f3 | |||
| e3a0b013c1 | |||
| f5763a94a0 | |||
| 8ada72eb57 | |||
| 2441b383d3 | |||
| 25f251699c | |||
| 7f37bcc6eb | |||
| 519c3a4d22 | |||
| 9dc4bcb46c | |||
| cb876c143b | |||
| bc652709a5 | |||
| 9548931258 | |||
| 5c5a5da664 | |||
| aa9ef59aa5 | |||
| 09e52c0500 | |||
| ca9063ffbe | |||
| 21d7973d11 | |||
| cc450e9c5f | |||
| 27465fe053 | |||
| 9667989727 | |||
| d9a1ddea0d | |||
| e7ab024ca0 | |||
| 448ccae959 | |||
| ec0348e431 | |||
| 06eda7f591 | |||
| 5fad24c16f | |||
| 8404244fab | |||
| 712cd01081 | |||
| 1f7aa359b1 | |||
| b138d6cf25 | |||
| fb7c808082 | |||
| a7e640b0f7 | |||
| 593604dfdc | |||
| b8f888f864 | |||
| 192b2ae621 | |||
| b7f8cb5094 | |||
| a23da6eb57 | |||
| 4c3aa40564 | |||
| 84e2c07a7e | |||
| 680af28bcc | |||
| d94db42ffe | |||
| 93cd83c55c | |||
| 5565fca3ac |
@@ -0,0 +1,15 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: true
|
||||||
|
poem: false
|
||||||
|
review_status: true
|
||||||
|
collapse_walkthrough: false
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
drafts: false
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
---
|
||||||
|
name: Bug Report
|
||||||
|
about: I found a defect
|
||||||
|
title: ''
|
||||||
|
labels: 'unconfirmed bug'
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||||
|
|
||||||
|
**Describe the bug**
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
**Expected behaviour**
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Operating system and version**
|
||||||
|
|
||||||
|
- OS: (linux, osx, windows, freebsd, etc)
|
||||||
|
- GPUs: (list architecture)
|
||||||
|
|
||||||
|
**My Configuration**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# copy / paste your configuration here
|
||||||
|
```
|
||||||
|
|
||||||
|
**Proxy Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy / paste from /logs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Upstream Logs**
|
||||||
|
|
||||||
|
```
|
||||||
|
# copy/paste from /logs
|
||||||
|
```
|
||||||
@@ -13,11 +13,11 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v9
|
- uses: actions/stale@v9
|
||||||
with:
|
with:
|
||||||
days-before-issue-stale: 30
|
days-before-issue-stale: 14
|
||||||
days-before-issue-close: 14
|
days-before-issue-close: 14
|
||||||
stale-issue-label: "stale"
|
stale-issue-label: "stale"
|
||||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity."
|
||||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale."
|
||||||
days-before-pr-stale: -1
|
days-before-pr-stale: -1
|
||||||
days-before-pr-close: -1
|
days-before-pr-close: -1
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
name: Windows CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
# Allows manual triggering of the workflow
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
# necessary for testing proxy/Process swapping
|
||||||
|
- name: Create simple-responder
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: make simple-responder-windows
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
|
- name: Test all
|
||||||
|
shell: bash
|
||||||
|
run: make test-all
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
# This workflow will build a golang project
|
name: Linux CI
|
||||||
|
|
||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -24,9 +22,33 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
go-version: '1.23'
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# Only run in this linux based runner
|
||||||
|
- name: Check Formatting
|
||||||
|
run: |
|
||||||
|
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||||
|
gofmt -l . | grep -v 'event/.*_test.go'
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
# cache simple-responder to save the build time
|
||||||
|
- name: Restore Simple Responder
|
||||||
|
id: restore-simple-responder
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
# necessary for testing proxy/Process swapping
|
# necessary for testing proxy/Process swapping
|
||||||
- name: Create simple-responder
|
- name: Create simple-responder
|
||||||
run: make simple-responder
|
run: make simple-responder
|
||||||
|
|
||||||
|
- name: Save Simple Responder
|
||||||
|
# nothing new to save ... skip this step
|
||||||
|
if: steps.restore-simple-responder.outputs.cache-hit != 'true'
|
||||||
|
id: save-simple-responder
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: ./build
|
||||||
|
key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }}
|
||||||
|
|
||||||
- name: Test all
|
- name: Test all
|
||||||
run: make test-all
|
run: make test-all
|
||||||
@@ -7,6 +7,10 @@ on:
|
|||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: 'Tag version to release (e.g. v144)'
|
||||||
|
required: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
@@ -20,9 +24,22 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
|
-
|
||||||
|
name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '23'
|
||||||
|
-
|
||||||
|
name: Install dependencies and build UI
|
||||||
|
run: |
|
||||||
|
cd ui
|
||||||
|
npm ci
|
||||||
|
npm run build
|
||||||
|
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v6
|
uses: goreleaser/goreleaser-action@v6
|
||||||
@@ -33,4 +50,30 @@ jobs:
|
|||||||
version: '~> v2'
|
version: '~> v2'
|
||||||
args: release --clean
|
args: release --clean
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
trigger-tap-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: goreleaser
|
||||||
|
steps:
|
||||||
|
- name: "Resolve tag to dispatch"
|
||||||
|
id: tag
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||||
|
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: "Trigger tap repository update"
|
||||||
|
uses: peter-evans/repository-dispatch@v2
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.TAP_REPO_PAT }}
|
||||||
|
repository: mostlygeek/homebrew-llama-swap
|
||||||
|
event-type: new-release
|
||||||
|
client-payload: |
|
||||||
|
{
|
||||||
|
"release": {
|
||||||
|
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,3 +4,4 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.dev/
|
||||||
|
|||||||
@@ -15,4 +15,18 @@ builds:
|
|||||||
- goos: freebsd
|
- goos: freebsd
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
|
||||||
|
archives:
|
||||||
|
- id: default
|
||||||
|
formats:
|
||||||
|
- tar.gz
|
||||||
|
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
|
builds_info:
|
||||||
|
group: root
|
||||||
|
owner: root
|
||||||
|
format_overrides:
|
||||||
|
# use zip format for windows
|
||||||
|
- goos: windows
|
||||||
|
formats:
|
||||||
|
- zip
|
||||||
@@ -19,24 +19,36 @@ all: mac linux simple-responder
|
|||||||
clean:
|
clean:
|
||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
proxy/ui_dist/placeholder.txt:
|
||||||
go test -short -v ./proxy
|
mkdir -p proxy/ui_dist
|
||||||
|
touch $@
|
||||||
|
|
||||||
test-all:
|
test: proxy/ui_dist/placeholder.txt
|
||||||
go test -v ./proxy
|
go test -short -v -count=1 ./proxy
|
||||||
|
|
||||||
|
test-all: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -v -count=1 ./proxy
|
||||||
|
|
||||||
|
ui/node_modules:
|
||||||
|
cd ui && npm install
|
||||||
|
|
||||||
|
# build react UI
|
||||||
|
ui: ui/node_modules
|
||||||
|
cd ui && npm run build
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac: ui
|
||||||
@echo "Building Mac binary..."
|
@echo "Building Mac binary..."
|
||||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||||
|
|
||||||
# Build Linux binary
|
# Build Linux binary
|
||||||
linux:
|
linux: ui
|
||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||||
|
|
||||||
# Build Windows binary
|
# Build Windows binary
|
||||||
windows:
|
windows: ui
|
||||||
@echo "Building Windows binary..."
|
@echo "Building Windows binary..."
|
||||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
@@ -46,6 +58,10 @@ simple-responder:
|
|||||||
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go
|
||||||
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go
|
||||||
|
|
||||||
|
simple-responder-windows:
|
||||||
|
@echo "Building simple responder for windows"
|
||||||
|
GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go
|
||||||
|
|
||||||
# Ensure build directory exists
|
# Ensure build directory exists
|
||||||
$(BUILD_DIR):
|
$(BUILD_DIR):
|
||||||
mkdir -p $(BUILD_DIR)
|
mkdir -p $(BUILD_DIR)
|
||||||
@@ -65,4 +81,4 @@ release:
|
|||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean mac linux windows simple-responder
|
.PHONY: all clean ui mac linux windows simple-responder
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||

|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
# llama-swap
|
# llama-swap
|
||||||
|
|
||||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary, a provided docker images or Homebrew.
|
||||||
|
|
||||||
## Features:
|
## Features:
|
||||||
|
|
||||||
@@ -15,142 +18,88 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
|
- ✅ llama-server (llama.cpp) supported endpoints:
|
||||||
|
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||||
|
- `/infill` - for code infilling
|
||||||
|
- `/completion` - for completion endpoint
|
||||||
- ✅ llama-swap custom API endpoints
|
- ✅ llama-swap custom API endpoints
|
||||||
|
- `/ui` - web UI
|
||||||
- `/log` - remote log monitoring
|
- `/log` - remote log monitoring
|
||||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
- `/health` - just returns "OK"
|
||||||
|
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
- ✅ Docker and Podman support
|
- ✅ Reliable Docker and Podman support with `cmdStart` and `cmdStop`
|
||||||
- ✅ Full control over server settings per model
|
- ✅ Full control over server settings per model
|
||||||
|
- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||||
|
|
||||||
## How does llama-swap work?
|
## How does llama-swap work?
|
||||||
|
|
||||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||||
|
|
||||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used.
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
llama-swap is managed entirely through a yaml configuration file.
|
||||||
|
|
||||||
|
It can be very minimal to start:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
models:
|
models:
|
||||||
"qwen2.5":
|
"qwen2.5":
|
||||||
proxy: "http://127.0.0.1:9999"
|
cmd: |
|
||||||
cmd: >
|
/path/to/llama-server
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
--port 9999
|
--port ${PORT}
|
||||||
|
|
||||||
"smollm2":
|
|
||||||
proxy: "http://127.0.0.1:9999"
|
|
||||||
cmd: >
|
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
|
||||||
--port 9999
|
|
||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
However, there are many more capabilities that llama-swap supports:
|
||||||
<summary>But also very powerful ...</summary>
|
|
||||||
|
|
||||||
```yaml
|
- `groups` to run multiple models at once
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
- `ttl` to automatically unload models
|
||||||
# Default (and minimum) is 15 seconds
|
- `macros` for reusable snippets
|
||||||
healthCheckTimeout: 60
|
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||||
|
- `env` to pass custom environment variables to inference servers
|
||||||
|
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||||
|
- `useModelName` to override model names sent to upstream servers
|
||||||
|
- `healthCheckTimeout` to control model startup wait times
|
||||||
|
- `${PORT}` automatic port variables for dynamic port assignment
|
||||||
|
|
||||||
# Write HTTP logs (useful for troubleshooting), defaults to false
|
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
|
||||||
logRequests: true
|
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
## Web UI
|
||||||
models:
|
|
||||||
"llama":
|
|
||||||
# multiline for readability
|
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
|
|
||||||
# environment variables to pass to the command
|
llama-swap includes a real time web interface for monitoring logs and models:
|
||||||
env:
|
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
|
||||||
proxy: http://127.0.0.1:8999
|
|
||||||
|
|
||||||
# aliases names to use this model for
|
The Activity Page shows recent requests:
|
||||||
aliases:
|
|
||||||
- "gpt-4o-mini"
|
|
||||||
- "gpt-3.5-turbo"
|
|
||||||
|
|
||||||
# check this path for an HTTP 200 OK before serving requests
|
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||||
# default: /health to match llama.cpp
|
|
||||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
|
||||||
# until the model is ready
|
|
||||||
checkEndpoint: /custom-endpoint
|
|
||||||
|
|
||||||
# automatically unload the model after this many seconds
|
## Installation
|
||||||
# ttl values must be a value greater than 0
|
|
||||||
# default: 0 = never unload model
|
|
||||||
ttl: 60
|
|
||||||
|
|
||||||
# `useModelName` overrides the model name in the request
|
llama-swap can be installed in multiple ways
|
||||||
# and sends a specific name to the upstream server
|
|
||||||
useModelName: "qwen:qwq"
|
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
1. Docker
|
||||||
# but they can still be requested as normal
|
2. Homebrew (OSX and Linux)
|
||||||
"qwen-unlisted":
|
3. From release binaries
|
||||||
unlisted: true
|
4. From source
|
||||||
cmd: llama-server --port 9999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
"docker-llama":
|
|
||||||
proxy: "http://127.0.0.1:9790"
|
|
||||||
cmd: >
|
|
||||||
docker run --name dockertest
|
|
||||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
|
||||||
|
|
||||||
# profiles eliminates swapping by running multiple models at the same time
|
Docker images with llama-swap and llama-server are built nightly.
|
||||||
#
|
|
||||||
# Tips:
|
|
||||||
# - each model must be listening on a unique address and port
|
|
||||||
# - the model name is in this format: "profile_name:model", like "coding:qwen"
|
|
||||||
# - the profile will load and unload all models in the profile at the same time
|
|
||||||
profiles:
|
|
||||||
coding:
|
|
||||||
- "llama"
|
|
||||||
- "qwen-unlisted"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Use Case Examples
|
```shell
|
||||||
|
# use CPU inference comes with the example config above
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
|
||||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
llama-s
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
|
||||||
|
|
||||||
Docker is the quickest way to try out llama-swap:
|
|
||||||
|
|
||||||
```
|
|
||||||
# use CPU inference
|
|
||||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
|
||||||
|
|
||||||
# qwen2.5 0.5B
|
# qwen2.5 0.5B
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
@@ -158,7 +107,6 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
|||||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
jq -r '.choices[0].message.content'
|
jq -r '.choices[0].message.content'
|
||||||
|
|
||||||
|
|
||||||
# SmolLM2 135M
|
# SmolLM2 135M
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
@@ -168,7 +116,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
|||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Docker images are nightly ...</summary>
|
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
|
||||||
|
|
||||||
They include:
|
They include:
|
||||||
|
|
||||||
@@ -182,7 +130,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
|
|||||||
|
|
||||||
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
|
||||||
|
|
||||||
```
|
```shell
|
||||||
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
||||||
-v /path/to/models:/models \
|
-v /path/to/models:/models \
|
||||||
-v /path/to/custom/config.yaml:/app/config.yaml \
|
-v /path/to/custom/config.yaml:/app/config.yaml \
|
||||||
@@ -191,34 +139,59 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
### Homebrew Install (macOS/Linux)
|
||||||
|
|
||||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Set up tap and install formula
|
||||||
|
brew tap mostlygeek/llama-swap
|
||||||
|
brew install llama-swap
|
||||||
|
# Run llama-swap
|
||||||
|
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
|
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
|
||||||
|
|
||||||
|
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||||
|
|
||||||
|
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`
|
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||||
|
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
|
||||||
|
Available flags:
|
||||||
|
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||||
|
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||||
|
- `--version`: Show version information and exit.
|
||||||
|
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||||
|
|
||||||
### Building from source
|
### Building from source
|
||||||
|
|
||||||
1. Install golang for your system
|
1. Build requires golang and nodejs for the user interface.
|
||||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||||
1. `make clean all`
|
1. `make clean all`
|
||||||
1. Binaries will be in `build/` subdirectory
|
1. Binaries will be in `build/` subdirectory
|
||||||
|
|
||||||
## Monitoring Logs
|
## Monitoring Logs
|
||||||
|
|
||||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
|
||||||
|
|
||||||
Of course, CLI access is also supported:
|
CLI access is also supported:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
curl http://host/logs'
|
curl http://host/logs'
|
||||||
|
|
||||||
# streams logs
|
# streams combined logs
|
||||||
curl -Ns 'http://host/logs/stream'
|
curl -Ns 'http://host/logs/stream'
|
||||||
|
|
||||||
|
# just llama-swap's logs
|
||||||
|
curl -Ns 'http://host/logs/stream/proxy'
|
||||||
|
|
||||||
|
# just upstream's logs
|
||||||
|
curl -Ns 'http://host/logs/stream/upstream'
|
||||||
|
|
||||||
# stream and filter logs with linux pipes
|
# stream and filter logs with linux pipes
|
||||||
curl -Ns http://host/logs/stream | grep 'eval time'
|
curl -Ns http://host/logs/stream | grep 'eval time'
|
||||||
|
|
||||||
@@ -232,36 +205,9 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
|
|||||||
|
|
||||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||||
|
|
||||||
## Systemd Unit Files
|
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
|
||||||
|
|
||||||
`/etc/systemd/system/llama-swap.service`
|
|
||||||
|
|
||||||
```
|
|
||||||
[Unit]
|
|
||||||
Description=llama-swap
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
User=nobody
|
|
||||||
|
|
||||||
# set this to match your environment
|
|
||||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
|
||||||
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=3
|
|
||||||
StartLimitBurst=3
|
|
||||||
StartLimitInterval=30
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
<picture>
|
> [!NOTE]
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date&theme=dark" />
|
> ⭐️ Star this project to help others discover it!
|
||||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
|
||||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date" />
|
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||||
</picture>
|
|
||||||
|
|||||||
@@ -1,92 +1,249 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# llama-swap YAML configuration example
|
||||||
# Default (and minimum): 15 seconds
|
# -------------------------------------
|
||||||
healthCheckTimeout: 15
|
#
|
||||||
|
# 💡 Tip - Use an LLM with this file!
|
||||||
|
# ====================================
|
||||||
|
# This example configuration is written to be LLM friendly. Try
|
||||||
|
# copying this file into an LLM and asking it to explain or generate
|
||||||
|
# sections for you.
|
||||||
|
# ====================================
|
||||||
|
|
||||||
# Log HTTP requests helpful for troubleshoot, defaults to False
|
# Usage notes:
|
||||||
logRequests: true
|
# - Below are all the available configuration options for llama-swap.
|
||||||
|
# - Settings noted as "required" must be in your configuration file
|
||||||
|
# - Settings noted as "optional" can be omitted
|
||||||
|
|
||||||
|
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||||
|
# - optional, default: 120
|
||||||
|
# - minimum value is 15 seconds, anything less will be set to this value
|
||||||
|
healthCheckTimeout: 500
|
||||||
|
|
||||||
|
# logLevel: sets the logging value
|
||||||
|
# - optional, default: info
|
||||||
|
# - Valid log levels: debug, info, warn, error
|
||||||
|
logLevel: info
|
||||||
|
|
||||||
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
|
# - optional, default: 1000
|
||||||
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
|
# - optional, default: 5800
|
||||||
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
|
# - it is automatically incremented for every model that uses it
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
|
# macros: a dictionary of string substitutions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros are reusable snippets
|
||||||
|
# - used in a model's cmd, cmdStop, proxy and checkEndpoint
|
||||||
|
# - useful for reducing common configuration settings
|
||||||
|
macros:
|
||||||
|
"latest-llama": >
|
||||||
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
# models: a dictionary of model configurations
|
||||||
|
# - required
|
||||||
|
# - each key is the model's ID, used in API requests
|
||||||
|
# - model settings have default values that are used if they are not defined here
|
||||||
|
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||||
|
# - below are examples of the all the settings a model can have
|
||||||
models:
|
models:
|
||||||
|
|
||||||
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
# cmd: the command to run to start the inference server.
|
||||||
models/llama-server-osx
|
# - required
|
||||||
--port 9001
|
# - it is just a string, similar to what you would run on the CLI
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
# - using `|` allows for comments in the command, these will be parsed out
|
||||||
proxy: http://127.0.0.1:9001
|
# - macros can be used within cmd
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# name: a display name for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
name: "llama 3.1 8B"
|
||||||
|
|
||||||
|
# description: a description for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
description: "A small but capable model used for quick testing"
|
||||||
|
|
||||||
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - each value is a single string
|
||||||
|
# - in the format: ENV_NAME=value
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||||
|
|
||||||
|
# proxy: the URL where llama-swap routes API requests
|
||||||
|
# - optional, default: http://localhost:${PORT}
|
||||||
|
# - if you used ${PORT} in cmd this can be omitted
|
||||||
|
# - if you use a custom port in cmd this *must* be set
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-4o-mini
|
- "gpt-4o-mini"
|
||||||
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# check this path for a HTTP 200 response for the server to be ready
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
checkEndpoint: /health
|
# - optional, default: /health
|
||||||
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
|
# - all requests wait until the endpoint is ready or fails
|
||||||
|
# - use "none" to skip endpoint health checking
|
||||||
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# unload model after 5 seconds
|
# ttl: automatically unload the model after ttl seconds
|
||||||
ttl: 5
|
# - optional, default: 0
|
||||||
|
# - ttl values must be a value greater than 0
|
||||||
|
# - a value of 0 disables automatic unloading of the model
|
||||||
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
# useModelName: override the model name that is sent to upstream server
|
||||||
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
# - optional, default: ""
|
||||||
proxy: http://127.0.0.1:9002
|
# - useful for when the upstream server expects a specific model name that
|
||||||
aliases:
|
# is different from the model's ID
|
||||||
- gpt-3.5-turbo
|
useModelName: "qwen:qwq"
|
||||||
|
|
||||||
# Embedding example with Nomic
|
# filters: a dictionary of filter settings
|
||||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
# - optional, default: empty dictionary
|
||||||
"nomic":
|
# - only strip_params is currently supported
|
||||||
proxy: http://127.0.0.1:9005
|
filters:
|
||||||
cmd: >
|
# strip_params: a comma separated list of parameters to remove from the request
|
||||||
models/llama-server-osx --port 9005
|
# - optional, default: ""
|
||||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
# - useful for server side enforcement of sampling parameters
|
||||||
--ctx-size 8192
|
# - the `model` parameter can never be removed
|
||||||
--batch-size 8192
|
# - can be any JSON key in the request body
|
||||||
--rope-scaling yarn
|
# - recommended to stick to sampling parameters
|
||||||
--rope-freq-scale 0.75
|
strip_params: "temperature, top_p, top_k"
|
||||||
-ngl 99
|
|
||||||
--embeddings
|
|
||||||
|
|
||||||
# Reranking example with bge-reranker
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
# - optional, default: 0
|
||||||
"bge-reranker":
|
# - useful for limiting the number of active parallel requests a model can process
|
||||||
proxy: http://127.0.0.1:9006
|
# - must be set per model
|
||||||
cmd: >
|
# - any number greater than 0 will override the internal default value of 10
|
||||||
models/llama-server-osx --port 9006
|
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
# - recommended to be omitted and the default used
|
||||||
--ctx-size 8192
|
concurrencyLimit: 0
|
||||||
--reranking
|
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# Unlisted model example:
|
||||||
"dockertest":
|
"qwen-unlisted":
|
||||||
proxy: "http://127.0.0.1:9790"
|
# unlisted: boolean, true or false
|
||||||
cmd: >
|
# - optional, default: false
|
||||||
docker run --name dockertest
|
# - unlisted models do not show up in /v1/models api requests
|
||||||
--init --rm -p 9790:8080 -v /mnt/nvme/models:/models
|
# - can be requested as normal through all apis
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
unlisted: true
|
||||||
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
|
# Docker example:
|
||||||
|
# container runtimes like Docker and Podman can be used reliably with
|
||||||
|
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
|
cmd: |
|
||||||
|
docker run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
"simple":
|
# cmdStop: command to run to stop the model gracefully
|
||||||
# example of setting environment variables
|
# - optional, default: ""
|
||||||
env:
|
# - useful for stopping commands managed by another system
|
||||||
- CUDA_VISIBLE_DEVICES=0,1
|
# - the upstream's process id is available in the ${PID} macro
|
||||||
- env1=hello
|
#
|
||||||
cmd: build/simple-responder --port 8999
|
# When empty, llama-swap has this default behaviour:
|
||||||
proxy: http://127.0.0.1:8999
|
# - on POSIX systems: a SIGTERM signal is sent
|
||||||
unlisted: true
|
# - on Windows, calls taskkill to stop the process
|
||||||
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# use "none" to skip check. Caution this may cause some requests to fail
|
# groups: a dictionary of group settings
|
||||||
# until the upstream server is ready for traffic
|
# - optional, default: empty dictionary
|
||||||
checkEndpoint: none
|
# - provides advanced controls over model swapping behaviour
|
||||||
|
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||||
|
# - model IDs must be defined in the Models section
|
||||||
|
# - a model can only be a member of one group
|
||||||
|
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||||
|
# - see issue #109 for details
|
||||||
|
#
|
||||||
|
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||||
|
groups:
|
||||||
|
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||||
|
# to run a time across the whole llama-swap instance
|
||||||
|
"group1":
|
||||||
|
# swap: controls the model swapping behaviour in within the group
|
||||||
|
# - optional, default: true
|
||||||
|
# - true : only one model is allowed to run at a time
|
||||||
|
# - false: all models can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
# don't use these, just for testing if things are broken
|
# exclusive: controls how the group affects other groups
|
||||||
"broken":
|
# - optional, default: true
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
# - true: causes all other groups to unload when this group runs a model
|
||||||
proxy: http://127.0.0.1:8999
|
# - false: does not affect other groups
|
||||||
unlisted: true
|
exclusive: true
|
||||||
"broken_timeout":
|
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
|
||||||
proxy: http://127.0.0.1:9000
|
|
||||||
unlisted: true
|
|
||||||
|
|
||||||
# creating a coding profile with models for code generation and general questions
|
# members references the models defined above
|
||||||
profiles:
|
# required
|
||||||
coding:
|
members:
|
||||||
- "qwen"
|
- "llama"
|
||||||
- "llama"
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - in group2 all models can run at the same time
|
||||||
|
# - when a different group is loaded it causes all running models in this group to unload
|
||||||
|
"group2":
|
||||||
|
swap: false
|
||||||
|
|
||||||
|
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||||
|
# - the models in group2 will be loaded but will not unload any other groups
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - a persistent group, prevents other groups from unloading it
|
||||||
|
"forever":
|
||||||
|
# persistent: prevents over groups from unloading the models in this group
|
||||||
|
# - optional, default: false
|
||||||
|
# - does not affect individual model behaviour
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# set swap/exclusive to false to prevent swapping inside the group
|
||||||
|
# and the unloading of other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
|
|
||||||
|
# hooks: a dictionary of event triggers and actions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported hook is on_startup
|
||||||
|
hooks:
|
||||||
|
# on_startup: a dictionary of actions to perform on startup
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported action is preload
|
||||||
|
on_startup:
|
||||||
|
# preload: a list of model ids to load on startup
|
||||||
|
# - optional, default: empty list
|
||||||
|
# - model names must match keys in the models sections
|
||||||
|
# - when preloading multiple models at once, define a group
|
||||||
|
# otherwise models will be loaded and swapped out
|
||||||
|
preload:
|
||||||
|
- "llama"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
healthCheckTimeout: 300
|
healthCheckTimeout: 300
|
||||||
logRequests: true
|
logRequests: true
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"qwen2.5":
|
"qwen2.5":
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||||
|
|
||||||
|
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default initializes a default in-process dispatcher
|
||||||
|
var Default = NewDispatcherConfig(25000)
|
||||||
|
|
||||||
|
// On subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work. This
|
||||||
|
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||||
|
func On[T Event](handler func(T)) context.CancelFunc {
|
||||||
|
return Subscribe(Default, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnType subscribes to an event with the specified event type. This functions
|
||||||
|
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||||
|
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
return SubscribeTo(Default, eventType, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit writes an event into the dispatcher. This functions same way as
|
||||||
|
// Publish() but uses the default dispatcher instead.
|
||||||
|
func Emit[T Event](ev T) {
|
||||||
|
Publish(Default, ev)
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||||
|
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||||
|
*/
|
||||||
|
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||||
|
unsub()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultPublish(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
var count int64
|
||||||
|
defer On(func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(4)
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(4), count)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event represents an event contract
|
||||||
|
type Event interface {
|
||||||
|
Type() uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// registry holds an immutable sorted array of event mappings
|
||||||
|
type registry struct {
|
||||||
|
keys []uint32 // Event types (sorted)
|
||||||
|
grps []any // Corresponding subscribers
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Dispatcher -------------------------------------
|
||||||
|
|
||||||
|
// Dispatcher represents an event dispatcher.
|
||||||
|
type Dispatcher struct {
|
||||||
|
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||||
|
done chan struct{} // Cancellation
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcher creates a new dispatcher of events.
|
||||||
|
func NewDispatcher() *Dispatcher {
|
||||||
|
return NewDispatcherConfig(50000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||||
|
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||||
|
d := &Dispatcher{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
maxQueue: maxQueue,
|
||||||
|
}
|
||||||
|
|
||||||
|
d.subs.Store(®istry{
|
||||||
|
keys: make([]uint32, 0, 16),
|
||||||
|
grps: make([]any, 0, 16),
|
||||||
|
})
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the dispatcher
|
||||||
|
func (d *Dispatcher) Close() error {
|
||||||
|
close(d.done)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClosed returns whether the dispatcher is closed or not
|
||||||
|
func (d *Dispatcher) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-d.done:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findGroup performs a lock-free binary search for the event type
|
||||||
|
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||||
|
reg := d.subs.Load()
|
||||||
|
keys := reg.keys
|
||||||
|
|
||||||
|
// Inlined binary search for better cache locality
|
||||||
|
left, right := 0, len(keys)
|
||||||
|
for left < right {
|
||||||
|
mid := left + (right-left)/2
|
||||||
|
if keys[mid] < eventType {
|
||||||
|
left = mid + 1
|
||||||
|
} else {
|
||||||
|
right = mid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if left < len(keys) && keys[left] == eventType {
|
||||||
|
return reg.grps[left]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work.
|
||||||
|
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||||
|
var event T
|
||||||
|
return SubscribeTo(broker, event.Type(), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeTo subscribes to an event with the specified event type.
|
||||||
|
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
if broker.isClosed() {
|
||||||
|
panic(errClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
broker.mu.Lock()
|
||||||
|
defer broker.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if group already exists
|
||||||
|
if existing := broker.findGroup(eventType); existing != nil {
|
||||||
|
grp := groupOf[T](eventType, existing)
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new group
|
||||||
|
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
|
||||||
|
// Copy-on-write: insert new entry in sorted position
|
||||||
|
old := broker.subs.Load()
|
||||||
|
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||||
|
return old.keys[i] >= eventType
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create new arrays with space for one more element
|
||||||
|
newKeys := make([]uint32, len(old.keys)+1)
|
||||||
|
newGrps := make([]any, len(old.grps)+1)
|
||||||
|
|
||||||
|
// Copy elements before insertion point
|
||||||
|
copy(newKeys[:idx], old.keys[:idx])
|
||||||
|
copy(newGrps[:idx], old.grps[:idx])
|
||||||
|
|
||||||
|
// Insert new element
|
||||||
|
newKeys[idx] = eventType
|
||||||
|
newGrps[idx] = grp
|
||||||
|
|
||||||
|
// Copy elements after insertion point
|
||||||
|
copy(newKeys[idx+1:], old.keys[idx:])
|
||||||
|
copy(newGrps[idx+1:], old.grps[idx:])
|
||||||
|
|
||||||
|
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||||
|
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||||
|
broker.subs.Store(newReg)
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish writes an event into the dispatcher
|
||||||
|
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||||
|
eventType := ev.Type()
|
||||||
|
if sub := broker.findGroup(eventType); sub != nil {
|
||||||
|
group := groupOf[T](eventType, sub)
|
||||||
|
group.Broadcast(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count counts the number of subscribers, this is for testing only.
|
||||||
|
func (d *Dispatcher) count(eventType uint32) int {
|
||||||
|
if group := d.findGroup(eventType); group != nil {
|
||||||
|
return group.(interface{ Count() int }).Count()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupOf casts the subscriber group to the specified generic type
|
||||||
|
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||||
|
if group, ok := subs.(*group[T]); ok {
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
panic(errConflict[T](eventType, subs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber -------------------------------------
|
||||||
|
|
||||||
|
// consumer represents a consumer with a message queue
|
||||||
|
type consumer[T Event] struct {
|
||||||
|
queue []T // Current work queue
|
||||||
|
stop bool // Stop signal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen listens to the event queue and processes events
|
||||||
|
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||||
|
pending := make([]T, 0, 128)
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.L.Lock()
|
||||||
|
for len(s.queue) == 0 {
|
||||||
|
switch {
|
||||||
|
case s.stop:
|
||||||
|
c.L.Unlock()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap buffers and reset the current queue
|
||||||
|
temp := s.queue
|
||||||
|
s.queue = pending[:0]
|
||||||
|
pending = temp
|
||||||
|
c.L.Unlock()
|
||||||
|
|
||||||
|
// Outside of the critical section, process the work
|
||||||
|
for _, event := range pending {
|
||||||
|
fn(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify potential publishers waiting due to backpressure
|
||||||
|
c.Broadcast()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber Group -------------------------------------
|
||||||
|
|
||||||
|
// group represents a consumer group
|
||||||
|
type group[T Event] struct {
|
||||||
|
cond *sync.Cond
|
||||||
|
subs []*consumer[T]
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
maxLen int // Current maximum queue length across all consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends an event to all consumers
|
||||||
|
func (s *group[T]) Broadcast(ev T) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Calculate current maximum queue length
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backpressure: wait if queues are full
|
||||||
|
for s.maxLen >= s.maxQueue {
|
||||||
|
s.cond.Wait()
|
||||||
|
|
||||||
|
// Recalculate after wakeup
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add event to all queues and track new maximum
|
||||||
|
newMax := 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
sub.queue = append(sub.queue, ev)
|
||||||
|
if len(sub.queue) > newMax {
|
||||||
|
newMax = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.maxLen = newMax
|
||||||
|
s.cond.Broadcast() // Wake consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a subscriber to the list
|
||||||
|
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||||
|
sub := &consumer[T]{
|
||||||
|
queue: make([]T, 0, 64),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the consumer to the list of active consumers
|
||||||
|
s.cond.L.Lock()
|
||||||
|
s.subs = append(s.subs, sub)
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Start listening
|
||||||
|
go sub.Listen(s.cond, handler)
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del removes a subscriber from the list
|
||||||
|
func (s *group[T]) Del(sub *consumer[T]) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Search and remove the subscriber
|
||||||
|
sub.stop = true
|
||||||
|
for i, v := range s.subs {
|
||||||
|
if v == sub {
|
||||||
|
copy(s.subs[i:], s.subs[i+1:])
|
||||||
|
s.subs = s.subs[:len(s.subs)-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Debugging -------------------------------------
|
||||||
|
|
||||||
|
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||||
|
|
||||||
|
// Count returns the number of subscribers in this group
|
||||||
|
func (s *group[T]) Count() int {
|
||||||
|
return len(s.subs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation of the type
|
||||||
|
func (s *group[T]) String() string {
|
||||||
|
typ := reflect.TypeOf(s).String()
|
||||||
|
idx := strings.LastIndex(typ, "/")
|
||||||
|
typ = typ[idx+1 : len(typ)-1]
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
// errConflict returns a conflict message
|
||||||
|
func errConflict[T any](eventType uint32, existing any) string {
|
||||||
|
var want T
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||||
|
want, existing, eventType,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublish(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe, must be received in order
|
||||||
|
var count int64
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(3)
|
||||||
|
Publish(d, MyEvent1{Number: 1})
|
||||||
|
Publish(d, MyEvent1{Number: 2})
|
||||||
|
Publish(d, MyEvent1{Number: 3})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(3), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsubscribe(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Nothing
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||||
|
unsubscribe()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrent(t *testing.T) {
|
||||||
|
const max = 1000000
|
||||||
|
var count int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
if current := atomic.AddInt64(&count, 1); current == max {
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Asynchronously publish
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < max; i++ {
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Subscriber that does nothing
|
||||||
|
})()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, max, int(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscribeDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseDispatcher(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||||
|
|
||||||
|
assert.NoError(t, d.Close())
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix(t *testing.T) {
|
||||||
|
const amount = 1000
|
||||||
|
for _, subs := range []int{1, 10, 100} {
|
||||||
|
for _, topics := range []int{1, 10} {
|
||||||
|
expected := subs * topics * amount
|
||||||
|
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||||
|
var count atomic.Int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(expected)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
for i := 0; i < subs; i++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||||
|
count.Add(1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < amount; n++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
go Publish(d, MyEvent3{ID: id})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, expected, int(count.Load()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||||
|
// This test specifically targets the race condition that occurs when multiple
|
||||||
|
// goroutines try to subscribe to different event types simultaneously.
|
||||||
|
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||||
|
|
||||||
|
const numGoroutines = 100
|
||||||
|
const numEventTypes = 50
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var receivedCount int64
|
||||||
|
var subscribedTypes sync.Map // Thread-safe map
|
||||||
|
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
// Start multiple goroutines that subscribe to different event types concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Each goroutine subscribes to a unique event type
|
||||||
|
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||||
|
|
||||||
|
// Subscribe to the event type
|
||||||
|
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&receivedCount, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Record that this type was subscribed
|
||||||
|
subscribedTypes.Store(eventType, true)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all subscriptions to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Count the number of unique event types subscribed
|
||||||
|
expectedTypes := 0
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
expectedTypes++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Small delay to ensure all subscriptions are fully processed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Publish events to each subscribed type
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
eventType := key.(uint32)
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for all events to be processed
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify that we received at least the expected number of events
|
||||||
|
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||||
|
received := atomic.LoadInt64(&receivedCount)
|
||||||
|
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||||
|
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||||
|
|
||||||
|
// Verify that we have the expected number of unique event types
|
||||||
|
assert.Equal(t, numEventTypes, expectedTypes,
|
||||||
|
"Should have exactly %d unique event types", numEventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to the same event type
|
||||||
|
t.Run("SameEventType", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var handlerCount int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&handlerCount, 1)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify all handlers were registered by publishing an event
|
||||||
|
atomic.StoreInt64(&handlerCount, 0)
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||||
|
"Not all handlers were registered due to race condition")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to different event types
|
||||||
|
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
receivedEvents := make(map[uint32]*int64)
|
||||||
|
|
||||||
|
// Create multiple event types and subscribe concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
eventType := uint32(100 + i)
|
||||||
|
counter := new(int64)
|
||||||
|
receivedEvents[eventType] = counter
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(et uint32, cnt *int64) {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(cnt, 1)
|
||||||
|
})
|
||||||
|
}(eventType, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Publish events to all types
|
||||||
|
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all event types received their events
|
||||||
|
for eventType, counter := range receivedEvents {
|
||||||
|
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||||
|
"Event type %d did not receive its event", eventType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackpressure(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
d.maxQueue = 10
|
||||||
|
|
||||||
|
var processedCount int64
|
||||||
|
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&processedCount, 1)
|
||||||
|
})
|
||||||
|
defer unsub()
|
||||||
|
|
||||||
|
const eventsToPublish = 1000
|
||||||
|
for i := 0; i < eventsToPublish; i++ {
|
||||||
|
Publish(d, MyEvent3{ID: 0x200})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all events were eventually processed
|
||||||
|
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||||
|
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||||
|
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Test Events -------------------------------------
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeEvent1 = 0x1
|
||||||
|
TypeEvent2 = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
type MyEvent1 struct {
|
||||||
|
Number int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||||
|
|
||||||
|
type MyEvent2 struct {
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||||
|
|
||||||
|
type MyEvent3 struct {
|
||||||
|
ID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# aider, QwQ, Qwen-Coder 2.5 and llama-swap
|
||||||
|
|
||||||
|
This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together.
|
||||||
|
|
||||||
|
## Here's what you you need:
|
||||||
|
|
||||||
|
- aider - [installation docs](https://aider.chat/docs/install.html)
|
||||||
|
- llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases)
|
||||||
|
- llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases)
|
||||||
|
- [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models
|
||||||
|
- 24GB VRAM video card
|
||||||
|
|
||||||
|
## Running aider
|
||||||
|
|
||||||
|
The goal is getting this command line to work:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
aider --architect \
|
||||||
|
--no-show-model-warnings \
|
||||||
|
--model openai/QwQ \
|
||||||
|
--editor-model openai/qwen-coder-32B \
|
||||||
|
--model-settings-file aider.model.settings.yml \
|
||||||
|
--openai-api-key "sk-na" \
|
||||||
|
--openai-api-base "http://10.0.1.24:8080/v1" \
|
||||||
|
```
|
||||||
|
|
||||||
|
Set `--openai-api-base` to the IP and port where your llama-swap is running.
|
||||||
|
|
||||||
|
## Create an aider model settings file
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# aider.model.settings.yml
|
||||||
|
|
||||||
|
#
|
||||||
|
# !!! important: model names must match llama-swap configuration names !!!
|
||||||
|
#
|
||||||
|
|
||||||
|
- name: "openai/QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
```
|
||||||
|
|
||||||
|
## llama-swap configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# config.yaml
|
||||||
|
|
||||||
|
# The parameters are tweaked to fit model+context into 24GB VRAM GPUs
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||||
|
--ctx-size 16000
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
-ngl 99
|
||||||
|
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
|
"QwQ":
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 9503 --flash-attn --metrics--slots
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
--ctx-size 32000
|
||||||
|
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||||
|
--temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5
|
||||||
|
--min-p 0.01 --top-k 40 --top-p 0.95
|
||||||
|
-ngl 99
|
||||||
|
--model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced, Dual GPU Configuration
|
||||||
|
|
||||||
|
If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder.
|
||||||
|
|
||||||
|
In llama-swap's configuration file:
|
||||||
|
|
||||||
|
1. add a `profiles` section with `aider` as the profile name
|
||||||
|
2. using the `env` field to specify the GPU IDs for each model
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# config.yaml
|
||||||
|
|
||||||
|
# Add a profile for aider
|
||||||
|
profiles:
|
||||||
|
aider:
|
||||||
|
- qwen-coder-32B
|
||||||
|
- QwQ
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
# manually set the GPU to run on
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
cmd: /path/to/llama-server ...
|
||||||
|
|
||||||
|
"QwQ":
|
||||||
|
# manually set the GPU to run on
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=1"
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
cmd: /path/to/llama-server ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Append the profile tag, `aider:`, to the model names in the model settings file
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# aider.model.settings.yml
|
||||||
|
- name: "openai/aider:QwQ"
|
||||||
|
weak_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
|
||||||
|
- name: "openai/aider:qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B-aider"
|
||||||
|
```
|
||||||
|
|
||||||
|
Run aider with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ aider --architect \
|
||||||
|
--no-show-model-warnings \
|
||||||
|
--model openai/aider:QwQ \
|
||||||
|
--editor-model openai/aider:qwen-coder-32B \
|
||||||
|
--config aider.conf.yml \
|
||||||
|
--model-settings-file aider.model.settings.yml
|
||||||
|
--openai-api-key "sk-na" \
|
||||||
|
--openai-api-base "http://10.0.1.24:8080/v1"
|
||||||
|
```
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# this makes use of llama-swap's profile feature to
|
||||||
|
# keep the architect and editor models in VRAM on different GPUs
|
||||||
|
|
||||||
|
- name: "openai/aider:QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/aider:qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/aider:qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/aider:qwen-coder-32B"
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
- name: "openai/QwQ"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.95
|
||||||
|
top_k: 40
|
||||||
|
presence_penalty: 0.1
|
||||||
|
repetition_penalty: 1
|
||||||
|
num_ctx: 16384
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
weak_model_name: "openai/qwen-coder-32B"
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
|
- name: "openai/qwen-coder-32B"
|
||||||
|
edit_format: diff
|
||||||
|
extra_params:
|
||||||
|
max_tokens: 16384
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 20
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
use_temperature: 0.6
|
||||||
|
reasoning_tag: think
|
||||||
|
editor_edit_format: editor-diff
|
||||||
|
editor_model_name: "openai/qwen-coder-32B"
|
||||||
|
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
healthCheckTimeout: 300
|
||||||
|
logLevel: debug
|
||||||
|
|
||||||
|
profiles:
|
||||||
|
aider:
|
||||||
|
- qwen-coder-32B
|
||||||
|
- QwQ
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen-coder-32B":
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0"
|
||||||
|
aliases:
|
||||||
|
- coder
|
||||||
|
proxy: "http://127.0.0.1:8999"
|
||||||
|
|
||||||
|
# set appropriate paths for your environment
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 8999 --flash-attn --slots
|
||||||
|
--ctx-size 16000
|
||||||
|
--ctx-size-draft 16000
|
||||||
|
--model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf
|
||||||
|
--model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf
|
||||||
|
-ngl 99 -ngld 99
|
||||||
|
--draft-max 16 --draft-min 4 --draft-p-min 0.4
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
"QwQ":
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=1"
|
||||||
|
proxy: "http://127.0.0.1:9503"
|
||||||
|
|
||||||
|
# set appropriate paths for your environment
|
||||||
|
cmd: >
|
||||||
|
/path/to/llama-server
|
||||||
|
--host 127.0.0.1 --port 9503
|
||||||
|
--flash-attn --metrics
|
||||||
|
--slots
|
||||||
|
--model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf
|
||||||
|
--cache-type-k q8_0 --cache-type-v q8_0
|
||||||
|
--ctx-size 32000
|
||||||
|
--samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc"
|
||||||
|
--temp 0.6
|
||||||
|
--repeat-penalty 1.1
|
||||||
|
--dry-multiplier 0.5
|
||||||
|
--min-p 0.01
|
||||||
|
--top-k 40
|
||||||
|
--top-p 0.95
|
||||||
|
-ngl 99 -ngld 99
|
||||||
@@ -3,8 +3,9 @@ module github.com/mostlygeek/llama-swap
|
|||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/billziss-gh/golib v0.2.0
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
@@ -37,7 +38,7 @@ require (
|
|||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.36.0 // indirect
|
golang.org/x/crypto v0.36.0 // indirect
|
||||||
golang.org/x/net v0.37.0 // indirect
|
golang.org/x/net v0.38.0 // indirect
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8=
|
||||||
|
github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
@@ -9,12 +11,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
@@ -23,9 +29,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
|||||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
@@ -74,32 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
|||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
|
||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
|
||||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
|
||||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
|
||||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|||||||
|
After Width: | Height: | Size: 351 KiB |
@@ -1,25 +1,35 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
var version string = "0"
|
var (
|
||||||
var commit string = "abcd1234"
|
version string = "0"
|
||||||
var date = "unknown"
|
commit string = "abcd1234"
|
||||||
|
date string = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
configPath := flag.String("config", "config.yaml", "config file name")
|
configPath := flag.String("config", "config.yaml", "config file name")
|
||||||
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
listenStr := flag.String("listen", ":8080", "listen ip/port")
|
||||||
showVersion := flag.Bool("version", false, "show version of build")
|
showVersion := flag.Bool("version", false, "show version of build")
|
||||||
|
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
@@ -34,26 +44,145 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(config.Profiles) > 0 {
|
||||||
|
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||||
|
}
|
||||||
|
|
||||||
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
if mode := os.Getenv("GIN_MODE"); mode != "" {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
} else {
|
} else {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
// Setup channels for server management
|
||||||
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Create server with initial handler
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: *listenStr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support for watching config and reloading when it changes
|
||||||
|
reloadProxyManager := func() {
|
||||||
|
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
config, err = proxy.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Configuration Changed")
|
||||||
|
currentPM.Shutdown()
|
||||||
|
srv.Handler = proxy.New(config)
|
||||||
|
fmt.Println("Configuration Reloaded")
|
||||||
|
|
||||||
|
// wait a few seconds and tell any UI to reload
|
||||||
|
time.AfterFunc(3*time.Second, func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateEnd,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
config, err = proxy.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
srv.Handler = proxy.New(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// load the initial proxy manager
|
||||||
|
reloadProxyManager()
|
||||||
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
|
if *watchConfig {
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
fmt.Println("Watching Configuration for changes")
|
||||||
|
go func() {
|
||||||
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Dir(absConfigPath)
|
||||||
|
err = watcher.Add(configDir)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer watcher.Close()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case changeEvent := <-watcher.Events:
|
||||||
|
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||||
|
// the change for k8s configmap
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case err := <-watcher.Errors:
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown on signal
|
||||||
go func() {
|
go func() {
|
||||||
<-sigChan
|
sig := <-sigChan
|
||||||
fmt.Println("Shutting down llama-swap")
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
proxyManager.Shutdown()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
os.Exit(0)
|
defer cancel()
|
||||||
|
|
||||||
|
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
pm.Shutdown()
|
||||||
|
} else {
|
||||||
|
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Printf("Server shutdown error: %v\n", err)
|
||||||
|
}
|
||||||
|
close(exitChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
fmt.Println("llama-swap listening on " + *listenStr)
|
// Start server
|
||||||
if err := proxyManager.Run(*listenStr); err != nil {
|
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||||
fmt.Printf("Server error: %v\n", err)
|
go func() {
|
||||||
os.Exit(1)
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("Fatal server error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for exit signal
|
||||||
|
<-exitChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func debounce(interval time.Duration, f func()) func() {
|
||||||
|
var timer *time.Timer
|
||||||
|
return func() {
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
timer = time.AfterFunc(interval, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||||
|
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||||
|
// to make sure all the requests are accounted for.
|
||||||
|
//
|
||||||
|
// requests can be sent in parallel, and the tool will report the results.
|
||||||
|
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// ----- CLI arguments ----------------------------------------------------
|
||||||
|
var (
|
||||||
|
baseurl string
|
||||||
|
modelName string
|
||||||
|
totalRequests int
|
||||||
|
parallelization int
|
||||||
|
)
|
||||||
|
|
||||||
|
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||||
|
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||||
|
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||||
|
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if baseurl == "" || modelName == "" {
|
||||||
|
fmt.Println("Error: both -baseurl and -model are required.")
|
||||||
|
flag.Usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if totalRequests <= 0 {
|
||||||
|
fmt.Println("Error: -requests must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if parallelization <= 0 {
|
||||||
|
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- HTTP client -------------------------------------------------------
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Tracking response codes -------------------------------------------
|
||||||
|
statusCounts := make(map[int]int) // map[statusCode]count
|
||||||
|
var mu sync.Mutex // protects statusCounts
|
||||||
|
|
||||||
|
// ----- Request queue (buffered channel) ----------------------------------
|
||||||
|
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||||
|
|
||||||
|
// Goroutine to fill the request queue
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < totalRequests; i++ {
|
||||||
|
requests <- i + 1
|
||||||
|
}
|
||||||
|
close(requests)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ----- Worker pool -------------------------------------------------------
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < parallelization; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(workerID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for reqID := range requests {
|
||||||
|
// Build request payload as a single line JSON string
|
||||||
|
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||||
|
|
||||||
|
// Send POST request
|
||||||
|
req, err := http.NewRequest(http.MethodPost,
|
||||||
|
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||||
|
bytes.NewReader([]byte(payload)))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// Record status code
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[resp.StatusCode]++
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Status ticker (prints every second) -------------------------------
|
||||||
|
done := make(chan struct{})
|
||||||
|
tickerDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
startTime := time.Now()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
mu.Lock()
|
||||||
|
// Compute how many requests have completed so far
|
||||||
|
completed := 0
|
||||||
|
for _, cnt := range statusCounts {
|
||||||
|
completed += cnt
|
||||||
|
}
|
||||||
|
// Calculate duration and progress
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
progress := completed * 100 / totalRequests
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||||
|
mu.Unlock()
|
||||||
|
case <-done:
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||||
|
close(tickerDone)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for all workers to finish
|
||||||
|
wg.Wait()
|
||||||
|
close(done) // stops the status-update goroutine
|
||||||
|
<-tickerDone // give ticker time to finish / print
|
||||||
|
|
||||||
|
// ----- Summary ------------------------------------------------------------
|
||||||
|
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||||
|
mu.Lock()
|
||||||
|
for code, cnt := range statusCounts {
|
||||||
|
if code == -1 {
|
||||||
|
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%d : %d\n", code, cnt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
**
|
||||||
|
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||||
|
|
||||||
|
- process is killed externally, what happens with cmd.Wait() *
|
||||||
|
✔︎ it returns. catches crashes.*
|
||||||
|
|
||||||
|
- process ignores SIGTERM*
|
||||||
|
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||||
|
|
||||||
|
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||||
|
x they stick around. have to be manually killed.*
|
||||||
|
|
||||||
|
- .WithTimeout()'s cancel is called *
|
||||||
|
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||||
|
|
||||||
|
- parent receives SIGINT/SIGTERM, what happens
|
||||||
|
✔︎ waits for child process to exit, then exits gracefully.
|
||||||
|
*/
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||||
|
cmd := exec.CommandContext(ctx,
|
||||||
|
"../../build/simple-responder_darwin_arm64",
|
||||||
|
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||||
|
)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// set a wait delay before signing sig kill
|
||||||
|
cmd.WaitDelay = 500 * time.Millisecond
|
||||||
|
cmd.Cancel = func() error {
|
||||||
|
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||||
|
cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
|
||||||
|
//return nil
|
||||||
|
|
||||||
|
// this error is returned by cmd.Wait(), and can be used to
|
||||||
|
// single an error when the process couldn't be normally terminated
|
||||||
|
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||||
|
// as WaitDelay timing out will override the any error set here.
|
||||||
|
//
|
||||||
|
// test by enabling/disabling -ignore-sig-term on the process
|
||||||
|
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||||
|
// without it, it will show the "new error from cancel"
|
||||||
|
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
fmt.Println("Error starting process:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||||
|
// this program to eventually exit gracefully.
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
signal := <-sigChan
|
||||||
|
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||||
|
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||||
|
}
|
||||||
|
fmt.Println("✔︎ Child process exited, Done.")
|
||||||
|
}
|
||||||
@@ -26,6 +26,8 @@ func main() {
|
|||||||
|
|
||||||
silent := flag.Bool("silent", false, "disable all logging")
|
silent := flag.Bool("silent", false, "disable all logging")
|
||||||
|
|
||||||
|
ignoreSigTerm := flag.Bool("ignore-sig-term", false, "ignore SIGTERM signal")
|
||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
// Create a new Gin router
|
// Create a new Gin router
|
||||||
@@ -33,14 +35,90 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// Check if streaming is requested
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||||
time.Sleep(wait)
|
isStreaming := c.Query("stream") == "true"
|
||||||
|
|
||||||
|
if isStreaming {
|
||||||
|
// Set headers for streaming
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 10 "asdf" tokens
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := gin.H{
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"choices": []gin.H{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": gin.H{
|
||||||
|
"content": "asdf",
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", data)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final data with usage info
|
||||||
|
finalData := gin.H{
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
// add timings to simulate llama.cpp
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", finalData)
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Send [DONE]
|
||||||
|
c.SSEvent("message", "[DONE]")
|
||||||
|
c.Writer.Flush()
|
||||||
|
} else {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
"request_body": string(bodyBytes),
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
c.String(200, *responseMessage)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -63,8 +141,29 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
// llama-server compatibility: /completion
|
||||||
|
r.POST("/completion", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
@@ -104,6 +203,10 @@ func main() {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
||||||
|
// expose some header values for testing
|
||||||
|
"h_content_type": c.GetHeader("Content-Type"),
|
||||||
|
"h_content_length": c.GetHeader("Content-Length"),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -180,6 +283,10 @@ func main() {
|
|||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !*silent {
|
||||||
|
fmt.Printf("My PID: %d\n", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("simple-responder listening on %s\n", address)
|
log.Printf("simple-responder listening on %s\n", address)
|
||||||
// service connections
|
// service connections
|
||||||
@@ -190,11 +297,36 @@ func main() {
|
|||||||
|
|
||||||
// Wait for interrupt signal to gracefully shutdown the server with
|
// Wait for interrupt signal to gracefully shutdown the server with
|
||||||
// a timeout of 5 seconds.
|
// a timeout of 5 seconds.
|
||||||
quit := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
// kill (no param) default send syscall.SIGTERM
|
// kill (no param) default send syscall.SIGTERM
|
||||||
// kill -2 is syscall.SIGINT
|
// kill -2 is syscall.SIGINT
|
||||||
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-quit
|
|
||||||
|
countSigInt := 0
|
||||||
|
|
||||||
|
runloop:
|
||||||
|
for {
|
||||||
|
signal := <-sigChan
|
||||||
|
switch signal {
|
||||||
|
case syscall.SIGINT:
|
||||||
|
countSigInt++
|
||||||
|
if countSigInt > 1 {
|
||||||
|
break runloop
|
||||||
|
} else {
|
||||||
|
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||||
|
}
|
||||||
|
case syscall.SIGTERM:
|
||||||
|
if *ignoreSigTerm {
|
||||||
|
log.Println("Ignoring SIGTERM")
|
||||||
|
} else {
|
||||||
|
log.Println("Received SIGTERM, shutting down")
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break runloop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("simple-responder shutting down")
|
log.Println("simple-responder shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
ui_dist/*
|
||||||
@@ -2,15 +2,24 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/shlex"
|
"github.com/billziss-gh/golib/shlex"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
|
||||||
type ModelConfig struct {
|
type ModelConfig struct {
|
||||||
Cmd string `yaml:"cmd"`
|
Cmd string `yaml:"cmd"`
|
||||||
|
CmdStop string `yaml:"cmdStop"`
|
||||||
Proxy string `yaml:"proxy"`
|
Proxy string `yaml:"proxy"`
|
||||||
Aliases []string `yaml:"aliases"`
|
Aliases []string `yaml:"aliases"`
|
||||||
Env []string `yaml:"env"`
|
Env []string `yaml:"env"`
|
||||||
@@ -18,20 +27,145 @@ type ModelConfig struct {
|
|||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
UseModelName string `yaml:"useModelName"`
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// #179 for /v1/models
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
|
|
||||||
|
// Model filters see issue #174
|
||||||
|
Filters ModelFilters `yaml:"filters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelConfig ModelConfig
|
||||||
|
defaults := rawModelConfig{
|
||||||
|
Cmd: "",
|
||||||
|
CmdStop: "",
|
||||||
|
Proxy: "http://localhost:${PORT}",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
UnloadAfter: 0,
|
||||||
|
Unlisted: false,
|
||||||
|
UseModelName: "",
|
||||||
|
ConcurrencyLimit: 0,
|
||||||
|
Name: "",
|
||||||
|
Description: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelConfig(defaults)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
return SanitizeCommand(m.Cmd)
|
return SanitizeCommand(m.Cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelFilters see issue #174
|
||||||
|
type ModelFilters struct {
|
||||||
|
StripParams string `yaml:"strip_params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelFilters ModelFilters
|
||||||
|
defaults := rawModelFilters{
|
||||||
|
StripParams: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelFilters(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||||
|
if f.StripParams == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := strings.Split(f.StripParams, ",")
|
||||||
|
cleaned := make([]string, 0, len(params))
|
||||||
|
|
||||||
|
for _, param := range params {
|
||||||
|
trimmed := strings.TrimSpace(param)
|
||||||
|
if trimmed == "model" || trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleaned = append(cleaned, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort cleaned
|
||||||
|
slices.Sort(cleaned)
|
||||||
|
return cleaned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupConfig struct {
|
||||||
|
Swap bool `yaml:"swap"`
|
||||||
|
Exclusive bool `yaml:"exclusive"`
|
||||||
|
Persistent bool `yaml:"persistent"`
|
||||||
|
Members []string `yaml:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for GroupConfig
|
||||||
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawGroupConfig GroupConfig
|
||||||
|
defaults := rawGroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Persistent: false,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = GroupConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HooksConfig struct {
|
||||||
|
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HookOnStartup struct {
|
||||||
|
Preload []string `yaml:"preload"`
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
LogRequests bool `yaml:"logRequests"`
|
LogRequests bool `yaml:"logRequests"`
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
|
Macros map[string]string `yaml:"macros"`
|
||||||
|
|
||||||
// map aliases to actual model IDs
|
// map aliases to actual model IDs
|
||||||
aliases map[string]string
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
|
|
||||||
|
// hooks, see: #209
|
||||||
|
Hooks HooksConfig `yaml:"hooks"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
@@ -52,42 +186,256 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (Config, error) {
|
||||||
data, err := os.ReadFile(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var config Config
|
// default configuration values
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
}
|
||||||
err = yaml.Unmarshal(data, &config)
|
err = yaml.Unmarshal(data, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Config{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
// set a minimum of 15 seconds
|
||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
// Populate the aliases map
|
// Populate the aliases map
|
||||||
config.aliases = make(map[string]string)
|
config.aliases = make(map[string]string)
|
||||||
for modelName, modelConfig := range config.Models {
|
for modelName, modelConfig := range config.Models {
|
||||||
for _, alias := range modelConfig.Aliases {
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
config.aliases[alias] = modelName
|
config.aliases[alias] = modelName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &config, nil
|
/* check macro constraint rules:
|
||||||
|
|
||||||
|
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
- names must be less than 64 characters (no reason, just cause)
|
||||||
|
- name can not be any reserved macros: PORT, MODEL_ID
|
||||||
|
- macro values must be less than 1024 characters
|
||||||
|
*/
|
||||||
|
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
for macroName, macroValue := range config.Macros {
|
||||||
|
if len(macroName) >= 64 {
|
||||||
|
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(macroName) {
|
||||||
|
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
|
||||||
|
}
|
||||||
|
if len(macroValue) >= 1024 {
|
||||||
|
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
|
||||||
|
}
|
||||||
|
switch macroName {
|
||||||
|
case "PORT":
|
||||||
|
case "MODEL_ID":
|
||||||
|
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
|
||||||
|
// Strip comments from command fields before macro expansion
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||||
|
for macroName, macroValue := range config.Macros {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enforce ${PORT} used in both cmd and proxy
|
||||||
|
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||||
|
nextPortStr := strconv.Itoa(nextPort)
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there are no unknown macros that have not been replaced
|
||||||
|
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // this is ok, has to be replaced by process later
|
||||||
|
}
|
||||||
|
if _, exists := config.Macros[macroName]; !exists {
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
// check that members are all unique in the groups
|
||||||
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
// Check for duplicates within this group
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
// Check if member is used in another group
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
|
if config.Groups == nil {
|
||||||
|
config.Groups = make(map[string]GroupConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGroup := GroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
// if groups is empty, create a default group and put
|
||||||
|
// all models into it
|
||||||
|
if len(config.Groups) == 0 {
|
||||||
|
for modelName := range config.Models {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over existing group members and add non-grouped models into the default group
|
||||||
|
for modelName, _ := range config.Models {
|
||||||
|
foundModel := false
|
||||||
|
found:
|
||||||
|
// search for the model in existing groups
|
||||||
|
for _, groupConfig := range config.Groups {
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if member == modelName {
|
||||||
|
foundModel = true
|
||||||
|
break found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundModel {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||||
|
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
// Remove trailing backslashes
|
var cleanedLines []string
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\ \n", " ")
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
cmdStr = strings.ReplaceAll(cmdStr, "\\\n", " ")
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
// Split the command into arguments
|
// Split the command into arguments
|
||||||
args, err := shlex.Split(cmdStr)
|
var args []string
|
||||||
if err != nil {
|
if runtime.GOOS == "windows" {
|
||||||
return nil, err
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the command is not empty
|
// Ensure the command is not empty
|
||||||
@@ -97,3 +445,16 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
|
|||||||
|
|
||||||
return args, nil
|
return args, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// Test a command with spaces and newlines
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
-a "double quotes" \
|
||||||
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
# comment 1
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the default values are automatically set for global, model and group configurations
|
||||||
|
// after loading the configuration
|
||||||
|
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadPosix(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
hooks:
|
||||||
|
on_startup:
|
||||||
|
preload: ["model1", "model2"]
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
name: "Model 1"
|
||||||
|
description: "This is model 1"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: map[string]string{
|
||||||
|
"svr-path": "path/to/server",
|
||||||
|
},
|
||||||
|
Hooks: HooksConfig{
|
||||||
|
OnStartup: HookOnStartup{
|
||||||
|
Preload: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
Name: "Model 1",
|
||||||
|
Description: "This is model 1",
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -1,90 +1,68 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"slices"
|
||||||
"path/filepath"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
// Create a temporary YAML file for testing
|
content := `
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
models:
|
||||||
if err != nil {
|
model1:
|
||||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
cmd: path/to/cmd --arg1 one
|
||||||
}
|
proxy: "http://localhost:8080"
|
||||||
defer os.RemoveAll(tempDir)
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
healthCheckTimeout: 15
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
group2:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// a Contains as order of the map is not guaranteed
|
||||||
|
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||||
content := `
|
content := `
|
||||||
models:
|
models:
|
||||||
model1:
|
model1:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8080"
|
proxy: "http://localhost:8080"
|
||||||
aliases:
|
aliases:
|
||||||
- "m1"
|
- m1
|
||||||
- "model-one"
|
|
||||||
env:
|
|
||||||
- "VAR1=value1"
|
|
||||||
- "VAR2=value2"
|
|
||||||
checkEndpoint: "/health"
|
|
||||||
model2:
|
model2:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8081"
|
proxy: "http://localhost:8081"
|
||||||
aliases:
|
|
||||||
- "m2"
|
|
||||||
checkEndpoint: "/"
|
checkEndpoint: "/"
|
||||||
healthCheckTimeout: 15
|
aliases:
|
||||||
profiles:
|
- m1
|
||||||
test:
|
- m2
|
||||||
- model1
|
|
||||||
- model2
|
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write temporary file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the config and verify
|
// Load the config and verify
|
||||||
config, err := LoadConfig(tempFile)
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := &Config{
|
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||||
Models: map[string]ModelConfig{
|
// go decided on the order of the map
|
||||||
"model1": {
|
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8080",
|
|
||||||
Aliases: []string{"m1", "model-one"},
|
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
},
|
|
||||||
"model2": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"m2"},
|
|
||||||
Env: nil,
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1", "model2"},
|
|
||||||
},
|
|
||||||
aliases: map[string]string{
|
|
||||||
"m1": "model1",
|
|
||||||
"model-one": "model1",
|
|
||||||
"m2": "model2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, expected, config)
|
|
||||||
|
|
||||||
realname, found := config.RealModelName("m1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", realname)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
@@ -147,30 +125,359 @@ func TestConfig_FindConfig(t *testing.T) {
|
|||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||||
|
|
||||||
// Test a command with spaces and newlines
|
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||||
args, err := SanitizeCommand(`python model1.py \
|
content := ``
|
||||||
-a "double quotes" \
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
--arg2 'single quotes'
|
if !assert.NoError(t, err) {
|
||||||
-s
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
--arg3 123 \
|
}
|
||||||
--arg4 '"string in string"'
|
|
||||||
-c "'single quoted'"
|
|
||||||
`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{
|
|
||||||
"python", "model1.py",
|
|
||||||
"-a", "double quotes",
|
|
||||||
"--arg2", "single quotes",
|
|
||||||
"-s",
|
|
||||||
"--arg3", "123",
|
|
||||||
"--arg4", `"string in string"`,
|
|
||||||
"-c", `'single quoted'`,
|
|
||||||
}, args)
|
|
||||||
|
|
||||||
// Test an empty command
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
args, err = SanitizeCommand("")
|
})
|
||||||
assert.Error(t, err)
|
t.Run("User specific port ranges", func(t *testing.T) {
|
||||||
assert.Nil(t, args)
|
content := `startPort: 1000`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1000, config.StartPort)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Invalid start port", func(t *testing.T) {
|
||||||
|
content := `startPort: abcd`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||||
|
content := `startPort: -99`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 5800
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
model2:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
proxy: "http://172.11.22.33:${PORT}"
|
||||||
|
model3:
|
||||||
|
cmd: svr --port 1999
|
||||||
|
proxy: "http://1.2.3.4:1999"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||||
|
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||||
|
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||||
|
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port 111
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Equal(t, "model model1: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroReplacement(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
argOne: "--arg1"
|
||||||
|
argTwo: "--arg2"
|
||||||
|
autoPort: "--port ${PORT}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: |
|
||||||
|
${svr-path} ${argTwo}
|
||||||
|
# the automatic ${PORT} is replaced
|
||||||
|
${autoPort}
|
||||||
|
${argOne}
|
||||||
|
--arg3 three
|
||||||
|
cmdStop: |
|
||||||
|
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
|
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unknown macro in cmd",
|
||||||
|
field: "cmd",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: |
|
||||||
|
${svr-path} --port ${PORT}
|
||||||
|
${unknownMacro}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in cmdStop",
|
||||||
|
field: "cmdStop",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
cmdStop: "kill ${unknownMacro}"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in proxy",
|
||||||
|
field: "proxy",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
proxy: "http://localhost:${unknownMacro}"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in checkEndpoint",
|
||||||
|
field: "checkEndpoint",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||||
|
//t.Log(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFilters(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
default_strip: "temperature, top_p"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
strip_params: "model, top_k, ${default_strip}, , ,"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
modelConfig, ok := config.Models["model1"]
|
||||||
|
if !assert.True(t, ok) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure `model` and enmpty strings are not in the list
|
||||||
|
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||||
|
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripComments(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no comments",
|
||||||
|
input: "echo hello\necho world",
|
||||||
|
expected: "echo hello\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single comment line",
|
||||||
|
input: "# this is a comment\necho hello",
|
||||||
|
expected: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple comment lines",
|
||||||
|
input: "# comment 1\necho hello\n# comment 2\necho world",
|
||||||
|
expected: "echo hello\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comment with spaces",
|
||||||
|
input: " # indented comment\necho hello",
|
||||||
|
expected: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty lines preserved",
|
||||||
|
input: "echo hello\n\necho world",
|
||||||
|
expected: "echo hello\n\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only comments",
|
||||||
|
input: "# comment 1\n# comment 2",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := StripComments(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
|
||||||
|
// Test case that reproduces the original bug where a macro in a comment
|
||||||
|
// would get expanded and cause the comment text to be included in the command
|
||||||
|
content := `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
"latest-llama": >
|
||||||
|
/user/llama.cpp/build/bin/llama-server
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
models:
|
||||||
|
"test-model":
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model /path/to/model.gguf
|
||||||
|
-ngl 99
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Get the sanitized command
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Join the command for easier inspection
|
||||||
|
cmdStr := strings.Join(sanitizedCmd, " ")
|
||||||
|
|
||||||
|
// Verify that comment text is NOT present in the final command as separate arguments
|
||||||
|
commentWords := []string{"is", "macro", "that", "defined", "above"}
|
||||||
|
for _, word := range commentWords {
|
||||||
|
found := slices.Contains(sanitizedCmd, word)
|
||||||
|
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the actual command components ARE present
|
||||||
|
expectedParts := []string{
|
||||||
|
"/user/llama.cpp/build/bin/llama-server",
|
||||||
|
"--port",
|
||||||
|
"9990",
|
||||||
|
"--model",
|
||||||
|
"/path/to/model.gguf",
|
||||||
|
"-ngl",
|
||||||
|
"99",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range expectedParts {
|
||||||
|
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the server path appears exactly once (not duplicated due to macro expansion)
|
||||||
|
serverPath := "/user/llama.cpp/build/bin/llama-server"
|
||||||
|
count := strings.Count(cmdStr, serverPath)
|
||||||
|
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
|
||||||
|
|
||||||
|
// Verify the expected final command structure
|
||||||
|
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
|
||||||
|
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroModelId(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 9000
|
||||||
|
macros:
|
||||||
|
"docker-llama": docker run --name ${MODEL_ID} -p ${PORT}:8080 docker_img
|
||||||
|
"docker-stop": docker stop ${MODEL_ID}
|
||||||
|
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||||
|
|
||||||
|
model2:
|
||||||
|
cmd: ${docker-llama}
|
||||||
|
cmdStop: ${docker-stop}
|
||||||
|
|
||||||
|
author/model:F16:
|
||||||
|
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||||
|
cmdStop: stop
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
|
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
|
||||||
|
|
||||||
|
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "docker run --name model2 -p 9002:8080 docker_img", strings.Join(sanitizedCmd2, " "))
|
||||||
|
|
||||||
|
sanitizedCmdStop, err := SanitizeCommand(config.Models["model2"].CmdStop)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "docker stop model2", strings.Join(sanitizedCmdStop, " "))
|
||||||
|
|
||||||
|
sanitizedCmd3, err := SanitizeCommand(config.Models["author/model:F16"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,231 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// does not support single quoted strings like in config_posix_test.go
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
|
||||||
|
-a "double quotes" \
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadWindows(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: map[string]string{
|
||||||
|
"svr-path": "path/to/server",
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||||
|
type DiscardWriter struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Header() http.Header {
|
||||||
|
if w.header == nil {
|
||||||
|
w.header = make(http.Header)
|
||||||
|
}
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) WriteHeader(code int) {
|
||||||
|
w.status = code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Satisfy the http.Flusher interface for streaming responses
|
||||||
|
func (w *DiscardWriter) Flush() {}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
// package level registry of the different event types
|
||||||
|
|
||||||
|
const ProcessStateChangeEventID = 0x01
|
||||||
|
const ChatCompletionStatsEventID = 0x02
|
||||||
|
const ConfigFileChangedEventID = 0x03
|
||||||
|
const LogDataEventID = 0x04
|
||||||
|
const TokenMetricsEventID = 0x05
|
||||||
|
const ModelPreloadedEventID = 0x06
|
||||||
|
|
||||||
|
type ProcessStateChangeEvent struct {
|
||||||
|
ProcessName string
|
||||||
|
NewState ProcessState
|
||||||
|
OldState ProcessState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||||
|
return ProcessStateChangeEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStats struct {
|
||||||
|
TokensGenerated int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ChatCompletionStats) Type() uint32 {
|
||||||
|
return ChatCompletionStatsEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReloadingState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReloadingStateStart ReloadingState = iota
|
||||||
|
ReloadingStateEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigFileChangedEvent struct {
|
||||||
|
ReloadingState ReloadingState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||||
|
return ConfigFileChangedEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogDataEvent struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e LogDataEvent) Type() uint32 {
|
||||||
|
return LogDataEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelPreloadedEvent struct {
|
||||||
|
ModelName string
|
||||||
|
Success bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ModelPreloadedEvent) Type() uint32 {
|
||||||
|
return ModelPreloadedEventID
|
||||||
|
}
|
||||||
@@ -9,11 +9,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nextTestPort int = 12000
|
nextTestPort int = 12000
|
||||||
portMutex sync.Mutex
|
portMutex sync.Mutex
|
||||||
|
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
simpleResponderPath = getSimpleResponderPath()
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
// Check if the binary exists
|
||||||
@@ -26,6 +29,17 @@ func TestMain(m *testing.M) {
|
|||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
switch os.Getenv("LOG_LEVEL") {
|
||||||
|
case "debug":
|
||||||
|
testLogger.SetLogLevel(LevelDebug)
|
||||||
|
case "warn":
|
||||||
|
testLogger.SetLogLevel(LevelWarn)
|
||||||
|
case "info":
|
||||||
|
testLogger.SetLogLevel(LevelInfo)
|
||||||
|
default:
|
||||||
|
testLogger.SetLogLevel(LevelWarn)
|
||||||
|
}
|
||||||
|
|
||||||
m.Run()
|
m.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,26 +47,39 @@ func TestMain(m *testing.M) {
|
|||||||
func getSimpleResponderPath() string {
|
func getSimpleResponderPath() string {
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
goarch := runtime.GOARCH
|
goarch := runtime.GOARCH
|
||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
|
||||||
|
if goos == "windows" {
|
||||||
|
return filepath.Join("..", "build", "simple-responder.exe")
|
||||||
|
} else {
|
||||||
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestPort() int {
|
||||||
portMutex.Lock()
|
portMutex.Lock()
|
||||||
defer portMutex.Unlock()
|
defer portMutex.Unlock()
|
||||||
|
|
||||||
port := nextTestPort
|
port := nextTestPort
|
||||||
nextTestPort++
|
nextTestPort++
|
||||||
|
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||||
binaryPath := getSimpleResponderPath()
|
// Create a YAML string with just the values we want to set
|
||||||
|
yamlStr := fmt.Sprintf(`
|
||||||
|
cmd: '%s --port %d --silent --respond %s'
|
||||||
|
proxy: "http://127.0.0.1:%d"
|
||||||
|
`, simpleResponderPath, port, expectedMessage, port)
|
||||||
|
|
||||||
// Create a process configuration
|
var cfg ModelConfig
|
||||||
return ModelConfig{
|
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
}
|
}
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 15 KiB |
@@ -1,14 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>llama-swap</title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<h1>llama-swap</h1>
|
|
||||||
<p>
|
|
||||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
|
||||||
</p>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,145 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Logs</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
margin: 0;
|
|
||||||
height: 100vh;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
font-family: "Courier New", Courier, monospace;
|
|
||||||
}
|
|
||||||
#log-controls {
|
|
||||||
margin: 0.5em;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: space-between; /* Spaces out elements evenly */
|
|
||||||
}
|
|
||||||
#log-controls input {
|
|
||||||
flex: 1;
|
|
||||||
}
|
|
||||||
#log-controls input:focus {
|
|
||||||
outline: none; /* Ensures no outline is shown when the input is focused */
|
|
||||||
}
|
|
||||||
#log-stream {
|
|
||||||
flex: 1;
|
|
||||||
margin: 0.5em;
|
|
||||||
padding: 1em;
|
|
||||||
background: #f4f4f4;
|
|
||||||
overflow-y: auto;
|
|
||||||
white-space: pre-wrap; /* Ensures line wrapping */
|
|
||||||
word-wrap: break-word; /* Ensures long words wrap */
|
|
||||||
}
|
|
||||||
|
|
||||||
.regex-error {
|
|
||||||
background-color: #ff0000 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Dark mode styles */
|
|
||||||
@media (prefers-color-scheme: dark) {
|
|
||||||
body {
|
|
||||||
background-color: #333;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
#log-stream {
|
|
||||||
background: #444;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
#log-controls input {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
|
|
||||||
#log-controls button {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<pre id="log-stream">Waiting for logs...</pre>
|
|
||||||
<div id="log-controls">
|
|
||||||
<input type="text" id="filter-input" placeholder="regex filter">
|
|
||||||
<button id="clear-button">clear</button>
|
|
||||||
</div>
|
|
||||||
<script>
|
|
||||||
const logStream = document.getElementById('log-stream');
|
|
||||||
const filterInput = document.getElementById('filter-input');
|
|
||||||
var logData = "";
|
|
||||||
let regexFilter = null;
|
|
||||||
|
|
||||||
function setupEventSource() {
|
|
||||||
if (typeof(EventSource) !== "undefined") {
|
|
||||||
const eventSource = new EventSource("/logs/streamSSE");
|
|
||||||
|
|
||||||
eventSource.onmessage = function(event) {
|
|
||||||
logData += event.data;
|
|
||||||
render()
|
|
||||||
};
|
|
||||||
|
|
||||||
eventSource.onerror = function(err) {
|
|
||||||
logData = "EventSource failed: " + err.message;
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
logData = "SSE Not supported by this browser."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// poor-ai's react ¯\_(ツ)_/¯
|
|
||||||
function render() {
|
|
||||||
if (regexFilter) {
|
|
||||||
const lines = logData.split('\n');
|
|
||||||
const filteredLines = lines.filter(line => {
|
|
||||||
return regexFilter === null || regexFilter.test(line);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (filteredLines.length > 0) {
|
|
||||||
logStream.textContent = filteredLines.join('\n') + '\n';
|
|
||||||
} else {
|
|
||||||
logStream.textContent = "";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logStream.textContent = logData;
|
|
||||||
}
|
|
||||||
|
|
||||||
logStream.scrollTop = logStream.scrollHeight;
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateFilter() {
|
|
||||||
const pattern = filterInput.value.trim();
|
|
||||||
filterInput.classList.remove('regex-error');
|
|
||||||
if (pattern) {
|
|
||||||
try {
|
|
||||||
regexFilter = new RegExp(pattern);
|
|
||||||
} catch (e) {
|
|
||||||
console.error("Invalid regex pattern:", e);
|
|
||||||
regexFilter = null;
|
|
||||||
filterInput.classList.add('regex-error');
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
regexFilter = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
render();
|
|
||||||
}
|
|
||||||
|
|
||||||
filterInput.addEventListener('input', updateFilter);
|
|
||||||
document.getElementById('clear-button').addEventListener('click', () => {
|
|
||||||
filterInput.value = "";
|
|
||||||
regexFilter = null;
|
|
||||||
render();
|
|
||||||
});
|
|
||||||
setupEventSource();
|
|
||||||
updateFilter();
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import "embed"
|
|
||||||
|
|
||||||
//go:embed html
|
|
||||||
var htmlFiles embed.FS
|
|
||||||
|
|
||||||
func getHTMLFile(path string) ([]byte, error) {
|
|
||||||
return htmlFiles.ReadFile("html/" + path)
|
|
||||||
}
|
|
||||||
@@ -2,19 +2,36 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"container/ring"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LogLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelDebug LogLevel = iota
|
||||||
|
LevelInfo
|
||||||
|
LevelWarn
|
||||||
|
LevelError
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
eventbus *event.Dispatcher
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
buffer *ring.Ring
|
buffer *ring.Ring
|
||||||
bufferMu sync.RWMutex
|
bufferMu sync.RWMutex
|
||||||
|
|
||||||
// typically this can be os.Stdout
|
// typically this can be os.Stdout
|
||||||
stdout io.Writer
|
stdout io.Writer
|
||||||
|
|
||||||
|
// logging levels
|
||||||
|
level LogLevel
|
||||||
|
prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLogMonitor() *LogMonitor {
|
func NewLogMonitor() *LogMonitor {
|
||||||
@@ -23,9 +40,11 @@ func NewLogMonitor() *LogMonitor {
|
|||||||
|
|
||||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
return &LogMonitor{
|
return &LogMonitor{
|
||||||
clients: make(map[chan []byte]bool),
|
eventbus: event.NewDispatcherConfig(1000),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
|
level: LevelInfo,
|
||||||
|
prefix: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,32 +84,86 @@ func (w *LogMonitor) GetHistory() []byte {
|
|||||||
return history
|
return history
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) Subscribe() chan []byte {
|
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||||
w.mu.Lock()
|
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||||
defer w.mu.Unlock()
|
callback(e.Data)
|
||||||
|
})
|
||||||
ch := make(chan []byte, 100)
|
|
||||||
w.clients[ch] = true
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
delete(w.clients, ch)
|
|
||||||
close(ch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) broadcast(msg []byte) {
|
func (w *LogMonitor) broadcast(msg []byte) {
|
||||||
w.mu.RLock()
|
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||||
defer w.mu.RUnlock()
|
}
|
||||||
|
|
||||||
for client := range w.clients {
|
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||||
select {
|
w.mu.Lock()
|
||||||
case client <- msg:
|
defer w.mu.Unlock()
|
||||||
default:
|
w.prefix = prefix
|
||||||
// If client buffer is full, skip
|
}
|
||||||
}
|
|
||||||
|
func (w *LogMonitor) SetLogLevel(level LogLevel) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
w.level = level
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) formatMessage(level string, msg string) []byte {
|
||||||
|
prefix := ""
|
||||||
|
if w.prefix != "" {
|
||||||
|
prefix = fmt.Sprintf("[%s] ", w.prefix)
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) log(level LogLevel, msg string) {
|
||||||
|
if level < w.level {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(w.formatMessage(level.String(), msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Debug(msg string) {
|
||||||
|
w.log(LevelDebug, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Info(msg string) {
|
||||||
|
w.log(LevelInfo, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Warn(msg string) {
|
||||||
|
w.log(LevelWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Error(msg string) {
|
||||||
|
w.log(LevelError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Debugf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelDebug, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Infof(format string, args ...interface{}) {
|
||||||
|
w.log(LevelInfo, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Warnf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelWarn, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LogMonitor) Errorf(format string, args ...interface{}) {
|
||||||
|
w.log(LevelError, fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LogLevel) String() string {
|
||||||
|
switch l {
|
||||||
|
case LevelDebug:
|
||||||
|
return "DEBUG"
|
||||||
|
case LevelInfo:
|
||||||
|
return "INFO"
|
||||||
|
case LevelWarn:
|
||||||
|
return "WARN"
|
||||||
|
case LevelError:
|
||||||
|
return "ERROR"
|
||||||
|
default:
|
||||||
|
return "UNKNOWN"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,38 +10,29 @@ import (
|
|||||||
func TestLogMonitor(t *testing.T) {
|
func TestLogMonitor(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
// Test subscription
|
// A WaitGroup is used to wait for all the expected writes to complete
|
||||||
client1 := logMonitor.Subscribe()
|
var wg sync.WaitGroup
|
||||||
client2 := logMonitor.Subscribe()
|
|
||||||
|
|
||||||
defer logMonitor.Unsubscribe(client1)
|
|
||||||
defer logMonitor.Unsubscribe(client2)
|
|
||||||
|
|
||||||
client1Messages := make([]byte, 0)
|
client1Messages := make([]byte, 0)
|
||||||
client2Messages := make([]byte, 0)
|
client2Messages := make([]byte, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
wg.Add(1)
|
client1Messages = append(client1Messages, data...)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
go func() {
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
defer wg.Done()
|
client2Messages = append(client2Messages, data...)
|
||||||
for {
|
wg.Done()
|
||||||
select {
|
})()
|
||||||
case data := <-client1:
|
|
||||||
client1Messages = append(client1Messages, data...)
|
wg.Add(6) // 2 x 3 writes
|
||||||
case data := <-client2:
|
|
||||||
client2Messages = append(client2Messages, data...)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logMonitor.Write([]byte("1"))
|
logMonitor.Write([]byte("1"))
|
||||||
logMonitor.Write([]byte("2"))
|
logMonitor.Write([]byte("2"))
|
||||||
logMonitor.Write([]byte("3"))
|
logMonitor.Write([]byte("3"))
|
||||||
|
|
||||||
// Wait for the goroutine to finish
|
// wait for all writes to complete
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Check the buffer
|
// Check the buffer
|
||||||
|
|||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MetricsRecorder struct {
|
||||||
|
metricsMonitor *MetricsMonitor
|
||||||
|
realModelName string
|
||||||
|
// isStreaming bool
|
||||||
|
startTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||||
|
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := &MetricsResponseWriter{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
metricsRecorder: &MetricsRecorder{
|
||||||
|
metricsMonitor: pm.metricsMonitor,
|
||||||
|
realModelName: realModelName,
|
||||||
|
startTime: time.Now(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.Writer = writer
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// check for streaming response
|
||||||
|
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
writer.metricsRecorder.processStreamingResponse(writer.body)
|
||||||
|
} else {
|
||||||
|
writer.metricsRecorder.processNonStreamingResponse(writer.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
||||||
|
usage := jsonData.Get("usage")
|
||||||
|
timings := jsonData.Get("timings")
|
||||||
|
if !usage.Exists() && !timings.Exists() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// default values
|
||||||
|
outputTokens := 0
|
||||||
|
inputTokens := 0
|
||||||
|
|
||||||
|
// timings data
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
||||||
|
|
||||||
|
if usage.Exists() {
|
||||||
|
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||||
|
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||||
|
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||||
|
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||||
|
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||||
|
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||||
|
}
|
||||||
|
|
||||||
|
rec.metricsMonitor.addMetrics(TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: rec.realModelName,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
DurationMs: durationMs,
|
||||||
|
})
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||||
|
// Iterate **backwards** through the lines looking for the data payload with
|
||||||
|
// usage data
|
||||||
|
lines := bytes.Split(body, []byte("\n"))
|
||||||
|
|
||||||
|
for i := len(lines) - 1; i >= 0; i-- {
|
||||||
|
line := bytes.TrimSpace(lines[i])
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payload always follows "data:"
|
||||||
|
prefix := []byte("data:")
|
||||||
|
if !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
// [DONE] line itself contains nothing of interest.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.ValidBytes(data) {
|
||||||
|
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
||||||
|
return // short circuit if a metric was recorded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON to extract usage information
|
||||||
|
if gjson.ValidBytes(body) {
|
||||||
|
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsResponseWriter captures the entire response for non-streaming
|
||||||
|
type MetricsResponseWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body []byte
|
||||||
|
metricsRecorder *MetricsRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
n, err := w.ResponseWriter.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
w.body = append(w.body, b...)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) Header() http.Header {
|
||||||
|
return w.ResponseWriter.Header()
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
|
type TokenMetrics struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
|
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenMetricsEvent represents a token metrics event
|
||||||
|
type TokenMetricsEvent struct {
|
||||||
|
Metrics TokenMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e TokenMetricsEvent) Type() uint32 {
|
||||||
|
return TokenMetricsEventID // defined in events.go
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsMonitor parses llama-server output for token statistics
|
||||||
|
type MetricsMonitor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
metrics []TokenMetrics
|
||||||
|
maxMetrics int
|
||||||
|
nextID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||||
|
maxMetrics := config.MetricsMaxInMemory
|
||||||
|
if maxMetrics <= 0 {
|
||||||
|
maxMetrics = 1000 // Default fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := &MetricsMonitor{
|
||||||
|
maxMetrics: maxMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMetrics adds a new metric to the collection and publishes an event
|
||||||
|
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
metric.ID = mp.nextID
|
||||||
|
mp.nextID++
|
||||||
|
mp.metrics = append(mp.metrics, metric)
|
||||||
|
if len(mp.metrics) > mp.maxMetrics {
|
||||||
|
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||||
|
}
|
||||||
|
|
||||||
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns a copy of the current metrics
|
||||||
|
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]TokenMetrics, len(mp.metrics))
|
||||||
|
copy(result, mp.metrics)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetricsJSON returns metrics as JSON
|
||||||
|
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
return json.Marshal(mp.metrics)
|
||||||
|
}
|
||||||
@@ -5,13 +5,17 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessState string
|
type ProcessState string
|
||||||
@@ -22,18 +26,30 @@ const (
|
|||||||
StateReady ProcessState = ProcessState("ready")
|
StateReady ProcessState = ProcessState("ready")
|
||||||
StateStopping ProcessState = ProcessState("stopping")
|
StateStopping ProcessState = ProcessState("stopping")
|
||||||
|
|
||||||
// failed a health check on start and will not be recovered
|
|
||||||
StateFailed ProcessState = ProcessState("failed")
|
|
||||||
|
|
||||||
// process is shutdown and will not be restarted
|
// process is shutdown and will not be restarted
|
||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StopStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StopImmediately StopStrategy = iota
|
||||||
|
StopWaitForInflightRequest
|
||||||
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
logMonitor *LogMonitor
|
|
||||||
|
// PR #155 called to cancel the upstream process
|
||||||
|
cancelUpstream context.CancelFunc
|
||||||
|
|
||||||
|
// closed when command exits
|
||||||
|
cmdWaitChan chan struct{}
|
||||||
|
|
||||||
|
processLogger *LogMonitor
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
|
||||||
healthCheckTimeout int
|
healthCheckTimeout int
|
||||||
healthCheckLoopInterval time.Duration
|
healthCheckLoopInterval time.Duration
|
||||||
@@ -48,26 +64,48 @@ type Process struct {
|
|||||||
// used to block on multiple start() calls
|
// used to block on multiple start() calls
|
||||||
waitStarting sync.WaitGroup
|
waitStarting sync.WaitGroup
|
||||||
|
|
||||||
// for managing shutdown state
|
// for managing concurrency limits
|
||||||
shutdownCtx context.Context
|
concurrencyLimitSemaphore chan struct{}
|
||||||
shutdownCancel context.CancelFunc
|
|
||||||
|
// used for testing to override the default value
|
||||||
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// track the number of failed starts
|
||||||
|
failedStartCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
concurrentLimit := 10
|
||||||
|
if config.ConcurrencyLimit > 0 {
|
||||||
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
|
}
|
||||||
|
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
cancelUpstream: nil,
|
||||||
|
processLogger: processLogger,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
|
||||||
shutdownCancel: cancel,
|
// concurrency limit
|
||||||
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
|
|
||||||
|
// To be removed when migration over exec.CommandContext is complete
|
||||||
|
// stop timeout
|
||||||
|
gracefulStopTimeout: 10 * time.Second,
|
||||||
|
cmdWaitChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogMonitor returns the log monitor associated with the process.
|
||||||
|
func (p *Process) LogMonitor() *LogMonitor {
|
||||||
|
return p.processLogger
|
||||||
|
}
|
||||||
|
|
||||||
// custom error types for swapping state
|
// custom error types for swapping state
|
||||||
var (
|
var (
|
||||||
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
ErrExpectedStateMismatch = errors.New("expected state mismatch")
|
||||||
@@ -81,14 +119,18 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
defer p.stateMutex.Unlock()
|
defer p.stateMutex.Unlock()
|
||||||
|
|
||||||
if p.state != expectedState {
|
if p.state != expectedState {
|
||||||
|
p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState)
|
||||||
return p.state, ErrExpectedStateMismatch
|
return p.state, ErrExpectedStateMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isValidTransition(p.state, newState) {
|
if !isValidTransition(p.state, newState) {
|
||||||
|
p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState)
|
||||||
return p.state, ErrInvalidStateTransition
|
return p.state, ErrInvalidStateTransition
|
||||||
}
|
}
|
||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
|
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,12 +140,12 @@ func isValidTransition(from, to ProcessState) bool {
|
|||||||
case StateStopped:
|
case StateStopped:
|
||||||
return to == StateStarting
|
return to == StateStarting
|
||||||
case StateStarting:
|
case StateStarting:
|
||||||
return to == StateReady || to == StateFailed || to == StateStopping
|
return to == StateReady || to == StateStopping || to == StateStopped
|
||||||
case StateReady:
|
case StateReady:
|
||||||
return to == StateStopping
|
return to == StateStopping
|
||||||
case StateStopping:
|
case StateStopping:
|
||||||
return to == StateStopped || to == StateShutdown
|
return to == StateStopped || to == StateShutdown
|
||||||
case StateFailed, StateShutdown:
|
case StateShutdown:
|
||||||
return false // No transitions allowed from these states
|
return false // No transitions allowed from these states
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -150,25 +192,37 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
p.waitStarting.Add(1)
|
p.waitStarting.Add(1)
|
||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
|
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.logMonitor
|
p.cmd.Stdout = p.processLogger
|
||||||
p.cmd.Stderr = p.logMonitor
|
p.cmd.Stderr = p.processLogger
|
||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||||
|
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||||
|
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||||
|
p.cancelUpstream = ctxCancelUpstream
|
||||||
|
p.cmdWaitChan = make(chan struct{})
|
||||||
|
|
||||||
|
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
// Set process state to failed
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); err != nil {
|
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||||
|
p.state = StateStopped // force it into a stopped state
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
err, curState, swapErr,
|
strings.Join(args, " "), err, curState, swapErr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture the exit error for later signalling
|
||||||
|
go p.waitForCmd()
|
||||||
|
|
||||||
// One of three things can happen at this stage:
|
// One of three things can happen at this stage:
|
||||||
// 1. The command exits unexpectedly
|
// 1. The command exits unexpectedly
|
||||||
// 2. The health check fails
|
// 2. The health check fails
|
||||||
@@ -183,50 +237,38 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||||
if checkEndpoint != "none" {
|
if checkEndpoint != "none" {
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
|
||||||
context.Background(),
|
|
||||||
checkStartTime.Add(maxDuration),
|
|
||||||
)
|
|
||||||
defer cancelHealthCheck()
|
|
||||||
|
|
||||||
loop:
|
|
||||||
// Ready Check loop
|
// Ready Check loop
|
||||||
for {
|
for {
|
||||||
select {
|
currentState := p.CurrentState()
|
||||||
case <-checkDeadline.Done():
|
if currentState != StateStarting {
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
if currentState == StateStopped {
|
||||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||||
} else {
|
|
||||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
|
||||||
}
|
}
|
||||||
case <-p.shutdownCtx.Done():
|
|
||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
default:
|
|
||||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
|
||||||
cancelHealthCheck()
|
|
||||||
break loop
|
|
||||||
} else {
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
endTime, _ := checkDeadline.Deadline()
|
|
||||||
ttl := time.Until(endTime)
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if time.Since(checkStartTime) > maxDuration {
|
||||||
|
p.stopCommand()
|
||||||
|
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
|
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||||
|
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
<-time.After(p.healthCheckLoopInterval)
|
<-time.After(p.healthCheckLoopInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,7 +288,7 @@ func (p *Process) start() error {
|
|||||||
p.inFlightRequests.Wait()
|
p.inFlightRequests.Wait()
|
||||||
|
|
||||||
if time.Since(p.lastRequestHandled) > maxDuration {
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
|
p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -257,83 +299,83 @@ func (p *Process) start() error {
|
|||||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
} else {
|
} else {
|
||||||
|
p.failedStartCount = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
// wait for any inflight requests before proceeding
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
p.inFlightRequests.Wait()
|
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
// wait for any inflight requests before proceeding
|
||||||
p.stopCommand(5 * time.Second)
|
p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID)
|
||||||
|
p.inFlightRequests.Wait()
|
||||||
|
p.StopImmediately()
|
||||||
|
}
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState)
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
|
func (p *Process) StopImmediately() {
|
||||||
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
||||||
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
|
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.stopCommand()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
// of time for any inflight requests to complete before shutting down. If the Process
|
// of time for any inflight requests to complete before shutting down. If the Process
|
||||||
// is in the state of starting, it will cancel it and shut it down
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
p.stopCommand(5 * time.Second)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.stopCommand()
|
||||||
|
// just force it to this state since there is no recovery from shutdown
|
||||||
p.state = StateShutdown
|
p.state = StateShutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
func (p *Process) stopCommand() {
|
||||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
stopStartTime := time.Now()
|
||||||
defer cancelTimeout()
|
defer func() {
|
||||||
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
sigtermNormal := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
sigtermNormal <- p.cmd.Wait()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
if p.cancelUpstream == nil {
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] cmd or cmd.Process is nil", p.ID)
|
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
p.cancelUpstream()
|
||||||
|
<-p.cmdWaitChan
|
||||||
select {
|
|
||||||
case <-sigtermTimeout.Done():
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] timed out waiting to stop, sending KILL signal\n", p.ID)
|
|
||||||
p.cmd.Process.Kill()
|
|
||||||
case err := <-sigtermNormal:
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] errno >> %v\n", p.ID, errno)
|
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] stopped OK\n", p.ID)
|
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] interrupted OK\n", p.ID)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] ExitError >> %v, exit code: %d\n", p.ID, exitError, exitError.ExitCode())
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(p.logMonitor, "!!! process [%s] exited >> %v\n", p.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
// wait a short time for a tcp connection to be established
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
|
||||||
|
// give a long time to respond to the health check endpoint
|
||||||
|
// after the connection is established. See issue: 276
|
||||||
|
Timeout: 5000 * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
@@ -356,14 +398,24 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestBeginTime := time.Now()
|
||||||
|
var startDuration time.Duration
|
||||||
|
|
||||||
// prevent new requests from being made while stopping or irrecoverable
|
// prevent new requests from being made while stopping or irrecoverable
|
||||||
currentState := p.CurrentState()
|
currentState := p.CurrentState()
|
||||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
if currentState == StateShutdown || currentState == StateStopping {
|
||||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||||
|
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||||
|
default:
|
||||||
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.lastRequestHandled = time.Now()
|
||||||
@@ -372,11 +424,13 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// start the process on demand
|
// start the process on demand
|
||||||
if p.CurrentState() != StateReady {
|
if p.CurrentState() != StateReady {
|
||||||
|
beginStartTime := time.Now()
|
||||||
if err := p.start(); err != nil {
|
if err := p.start(); err != nil {
|
||||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
http.Error(w, errstr, http.StatusBadGateway)
|
http.Error(w, errstr, http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
@@ -387,6 +441,12 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header = r.Header.Clone()
|
req.Header = r.Header.Clone()
|
||||||
|
|
||||||
|
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
req.ContentLength = contentLength
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
@@ -420,4 +480,84 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalTime := time.Since(requestBeginTime)
|
||||||
|
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||||
|
p.ID, r.RequestURI, startDuration, totalTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||||
|
func (p *Process) waitForCmd() {
|
||||||
|
exitErr := p.cmd.Wait()
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
|
|
||||||
|
if exitErr != nil {
|
||||||
|
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||||
|
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||||
|
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||||
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||||
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.CurrentState()
|
||||||
|
switch currentState {
|
||||||
|
case StateStopping:
|
||||||
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||||
|
p.state = StateStopped
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||||
|
p.state = StateStopped // force it to be in this state
|
||||||
|
}
|
||||||
|
close(p.cmdWaitChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||||
|
func (p *Process) cmdStopUpstreamProcess() error {
|
||||||
|
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||||
|
|
||||||
|
// this should never happen ...
|
||||||
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||||
|
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.CmdStop != "" {
|
||||||
|
// replace ${PID} with the pid of the process
|
||||||
|
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||||
|
if err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||||
|
|
||||||
|
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||||
|
stopCmd.Stdout = p.processLogger
|
||||||
|
stopCmd.Stderr = p.processLogger
|
||||||
|
stopCmd.Env = p.cmd.Env
|
||||||
|
|
||||||
|
if err := stopCmd.Run(); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,13 +13,26 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debugLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// flip to help with debugging tests
|
||||||
|
if false {
|
||||||
|
debugLogger.SetLogLevel(LevelDebug)
|
||||||
|
} else {
|
||||||
|
debugLogger.SetLogLevel(LevelError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
// Create a process
|
// Create a process
|
||||||
process := NewProcess("test-process", 5, config, logMonitor)
|
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
@@ -52,11 +65,10 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
// are all handled successfully, even though they all may ask for the process to .start()
|
// are all handled successfully, even though they all may ask for the process to .start()
|
||||||
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
||||||
|
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
process := NewProcess("test-process", 5, config, logMonitor)
|
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -84,7 +96,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -94,8 +106,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
|
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
@@ -109,7 +121,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
config.UnloadAfter = 3 // seconds
|
config.UnloadAfter = 3 // seconds
|
||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
|
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// this should take 4 seconds
|
// this should take 4 seconds
|
||||||
@@ -151,7 +163,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
|||||||
config.UnloadAfter = 1 // second
|
config.UnloadAfter = 1 // second
|
||||||
assert.Equal(t, 1, config.UnloadAfter)
|
assert.Equal(t, 1, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
|
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
@@ -178,7 +190,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
|
|||||||
|
|
||||||
expectedMessage := "12345"
|
expectedMessage := "12345"
|
||||||
config := getTestSimpleResponderConfig(expectedMessage)
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
|
process := NewProcess("t", 10, config, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
results := map[string]string{
|
results := map[string]string{
|
||||||
@@ -236,18 +248,14 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
|
||||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||||
|
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
|
||||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
|
||||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||||
@@ -255,9 +263,8 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
p := &Process{
|
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
|
||||||
state: test.currentState,
|
p.state = test.currentState
|
||||||
}
|
|
||||||
|
|
||||||
resultState, err := p.swapState(test.expectedState, test.newState)
|
resultState, err := p.swapState(test.expectedState, test.newState)
|
||||||
if err != nil && test.expectedError == nil {
|
if err != nil && test.expectedError == nil {
|
||||||
@@ -282,7 +289,6 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
t.Skip("skipping long shutdown test")
|
t.Skip("skipping long shutdown test")
|
||||||
}
|
}
|
||||||
|
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
|
|
||||||
// make a config where the healthcheck will always fail because port is wrong
|
// make a config where the healthcheck will always fail because port is wrong
|
||||||
@@ -290,7 +296,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
config.Proxy = "http://localhost:9998/test"
|
config.Proxy = "http://localhost:9998/test"
|
||||||
|
|
||||||
healthCheckTTLSeconds := 30
|
healthCheckTTLSeconds := 30
|
||||||
process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor)
|
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
|
||||||
|
|
||||||
// make it a lot faster
|
// make it a lot faster
|
||||||
process.healthCheckLoopInterval = time.Second
|
process.healthCheckLoopInterval = time.Second
|
||||||
@@ -311,3 +317,177 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
|
|||||||
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
|
||||||
assert.Equal(t, StateShutdown, process.CurrentState())
|
assert.Equal(t, StateShutdown, process.CurrentState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping Exit Interrupts Health Check test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// should run and exit but interrupt the long checkHealthTimeout
|
||||||
|
checkHealthTimeout := 5
|
||||||
|
config := ModelConfig{
|
||||||
|
Cmd: "sleep 1",
|
||||||
|
Proxy: "http://127.0.0.1:9913",
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
|
||||||
|
process.healthCheckLoopInterval = time.Second // make it faster
|
||||||
|
err := process.start()
|
||||||
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long concurrency limit test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "concurrency_limit_test"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// only allow 1 concurrent request at a time
|
||||||
|
config.ConcurrencyLimit = 1
|
||||||
|
|
||||||
|
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||||
|
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// launch a goroutine first to take up the semaphore
|
||||||
|
go func() {
|
||||||
|
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req1)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// let the goroutine start
|
||||||
|
<-time.After(time.Millisecond * 25)
|
||||||
|
|
||||||
|
denied := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, denied)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopImmediately(t *testing.T) {
|
||||||
|
expectedMessage := "test_stop_immediate"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediate
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
}()
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
|
// the upstream command
|
||||||
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("skipping SIGTERM test on Windows ")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "test_sigkill"
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
port := getTestPort()
|
||||||
|
|
||||||
|
config := ModelConfig{
|
||||||
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
|
// to force the process to exit
|
||||||
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// reduce to make testing go faster
|
||||||
|
process.gracefulStopTimeout = time.Second
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
|
||||||
|
waitChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// slow, but will get killed by StopImmediate
|
||||||
|
req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
|
// StatusOK because that was already sent before the kill
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// unexpected EOF because the kill happened, the "1" is sent before the kill
|
||||||
|
// then the unexpected EOF is sent after the kill
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||||
|
} else {
|
||||||
|
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(waitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-time.After(time.Millisecond)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
|
||||||
|
// the request should have been interrupted by SIGKILL
|
||||||
|
<-waitChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopCmd(t *testing.T) {
|
||||||
|
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
} else {
|
||||||
|
config.CmdStop = "kill -TERM ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||||
|
expectedMessage := "test_env_not_emptied"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// ensure that the the default config does not blank out the inherited environment
|
||||||
|
configWEnv := config
|
||||||
|
|
||||||
|
// ensure the additiona variables are appended to the process' environment
|
||||||
|
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||||
|
|
||||||
|
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
|
||||||
|
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||||
|
|
||||||
|
process1.start()
|
||||||
|
defer process1.Stop()
|
||||||
|
process2.start()
|
||||||
|
defer process2.Stop()
|
||||||
|
|
||||||
|
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||||
|
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||||
|
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProcessGroup struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
config Config
|
||||||
|
id string
|
||||||
|
swap bool
|
||||||
|
exclusive bool
|
||||||
|
persistent bool
|
||||||
|
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
|
||||||
|
// map of current processes
|
||||||
|
processes map[string]*Process
|
||||||
|
lastUsedProcess string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
|
groupConfig, ok := config.Groups[id]
|
||||||
|
if !ok {
|
||||||
|
panic("Unable to find configuration for group id: " + id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pg := &ProcessGroup{
|
||||||
|
id: id,
|
||||||
|
config: config,
|
||||||
|
swap: groupConfig.Swap,
|
||||||
|
exclusive: groupConfig.Exclusive,
|
||||||
|
persistent: groupConfig.Persistent,
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
|
processes: make(map[string]*Process),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a Process for each member in the group
|
||||||
|
for _, modelID := range groupConfig.Members {
|
||||||
|
modelConfig, modelID, _ := pg.config.FindConfig(modelID)
|
||||||
|
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger)
|
||||||
|
pg.processes[modelID] = process
|
||||||
|
}
|
||||||
|
|
||||||
|
return pg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyRequest proxies a request to the specified model
|
||||||
|
func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error {
|
||||||
|
if !pg.HasMember(modelID) {
|
||||||
|
return fmt.Errorf("model %s not part of group %s", modelID, pg.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pg.swap {
|
||||||
|
pg.Lock()
|
||||||
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// is there something already running?
|
||||||
|
if pg.lastUsedProcess != "" {
|
||||||
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the request to the new model to be fully handled
|
||||||
|
// and prevent race conditions see issue #277
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
|
pg.lastUsedProcess = modelID
|
||||||
|
|
||||||
|
// short circuit and exit
|
||||||
|
pg.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pg.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||||
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
|
pg.Lock()
|
||||||
|
defer pg.Unlock()
|
||||||
|
|
||||||
|
if len(pg.processes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop Processes in parallel
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) Shutdown() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(process *Process) {
|
||||||
|
defer wg.Done()
|
||||||
|
process.Shutdown()
|
||||||
|
}(process)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
|
"model4": getTestSimpleResponderConfig("model4"),
|
||||||
|
"model5": getTestSimpleResponderConfig("model5"),
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
"G2": {
|
||||||
|
Swap: false,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model3", "model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||||
|
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||||
|
assert.True(t, pg.HasMember("model5"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGroup_HasMember(t *testing.T) {
|
||||||
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
assert.True(t, pg.HasMember("model1"))
|
||||||
|
assert.True(t, pg.HasMember("model2"))
|
||||||
|
assert.False(t, pg.HasMember("model3"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
// use the same listening so if a model is already running, it will fail
|
||||||
|
// this is a way to test that swap isolation is working
|
||||||
|
// properly when there are parallel requests made at the
|
||||||
|
// same time.
|
||||||
|
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||||
|
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||||
|
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||||
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||||
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(len(tests))
|
||||||
|
for _, modelName := range tests {
|
||||||
|
go func(modelName string) {
|
||||||
|
defer wg.Done()
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
}(modelName)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
tests := []string{"model3", "model4"}
|
||||||
|
|
||||||
|
for _, modelName := range tests {
|
||||||
|
t.Run(modelName, func(t *testing.T) {
|
||||||
|
reqBody := `{"x", "y"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure all the processes are running
|
||||||
|
for _, process := range pg.processes {
|
||||||
|
assert.Equal(t, StateReady, process.CurrentState())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,11 +2,12 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -25,74 +27,184 @@ const (
|
|||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config *Config
|
config Config
|
||||||
currentProcesses map[string]*Process
|
ginEngine *gin.Engine
|
||||||
logMonitor *LogMonitor
|
|
||||||
ginEngine *gin.Engine
|
// logging
|
||||||
|
proxyLogger *LogMonitor
|
||||||
|
upstreamLogger *LogMonitor
|
||||||
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
|
metricsMonitor *MetricsMonitor
|
||||||
|
|
||||||
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// shutdown signaling
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
func New(config Config) *ProxyManager {
|
||||||
pm := &ProxyManager{
|
// set up loggers
|
||||||
config: config,
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||||
currentProcesses: make(map[string]*Process),
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
logMonitor: NewLogMonitor(),
|
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
ginEngine: gin.New(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.LogRequests {
|
if config.LogRequests {
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
||||||
// Start timer
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// capture these because /upstream/:model rewrites them in c.Next()
|
|
||||||
clientIP := c.ClientIP()
|
|
||||||
method := c.Request.Method
|
|
||||||
path := c.Request.URL.Path
|
|
||||||
|
|
||||||
// Process request
|
|
||||||
c.Next()
|
|
||||||
|
|
||||||
// Stop timer
|
|
||||||
duration := time.Since(start)
|
|
||||||
|
|
||||||
statusCode := c.Writer.Status()
|
|
||||||
bodySize := c.Writer.Size()
|
|
||||||
|
|
||||||
fmt.Fprintf(pm.logMonitor, "[llama-swap] %s [%s] \"%s %s %s\" %d %d \"%s\" %v\n",
|
|
||||||
clientIP,
|
|
||||||
time.Now().Format("2006-01-02 15:04:05"),
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
c.Request.Proto,
|
|
||||||
statusCode,
|
|
||||||
bodySize,
|
|
||||||
c.Request.UserAgent(),
|
|
||||||
duration,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// see: https://github.com/mostlygeek/llama-swap/issues/42
|
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
||||||
|
case "debug":
|
||||||
|
proxyLogger.SetLogLevel(LevelDebug)
|
||||||
|
upstreamLogger.SetLogLevel(LevelDebug)
|
||||||
|
case "info":
|
||||||
|
proxyLogger.SetLogLevel(LevelInfo)
|
||||||
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
|
case "warn":
|
||||||
|
proxyLogger.SetLogLevel(LevelWarn)
|
||||||
|
upstreamLogger.SetLogLevel(LevelWarn)
|
||||||
|
case "error":
|
||||||
|
proxyLogger.SetLogLevel(LevelError)
|
||||||
|
upstreamLogger.SetLogLevel(LevelError)
|
||||||
|
default:
|
||||||
|
proxyLogger.SetLogLevel(LevelInfo)
|
||||||
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
pm := &ProxyManager{
|
||||||
|
config: config,
|
||||||
|
ginEngine: gin.New(),
|
||||||
|
|
||||||
|
proxyLogger: proxyLogger,
|
||||||
|
muxLogger: stdoutLogger,
|
||||||
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
|
metricsMonitor: NewMetricsMonitor(&config),
|
||||||
|
|
||||||
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the process groups
|
||||||
|
for groupID := range config.Groups {
|
||||||
|
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
||||||
|
pm.processGroups[groupID] = processGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.setupGinEngine()
|
||||||
|
|
||||||
|
// run any startup hooks
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
// do it in the background, don't block startup -- not sure if good idea yet
|
||||||
|
go func() {
|
||||||
|
discardWriter := &DiscardWriter{}
|
||||||
|
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||||
|
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||||
|
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: realModelName,
|
||||||
|
Success: false,
|
||||||
|
})
|
||||||
|
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
req, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: realModelName,
|
||||||
|
Success: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return pm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) setupGinEngine() {
|
||||||
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
|
// Start timer
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// capture these because /upstream/:model rewrites them in c.Next()
|
||||||
|
clientIP := c.ClientIP()
|
||||||
|
method := c.Request.Method
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
// Process request
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// Stop timer
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
statusCode := c.Writer.Status()
|
||||||
|
bodySize := c.Writer.Size()
|
||||||
|
|
||||||
|
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
||||||
|
clientIP,
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
c.Request.Proto,
|
||||||
|
statusCode,
|
||||||
|
bodySize,
|
||||||
|
c.Request.UserAgent(),
|
||||||
|
duration,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// see: issue: #81, #77 and #42 for CORS issues
|
||||||
// respond with permissive OPTIONS for any endpoint
|
// respond with permissive OPTIONS for any endpoint
|
||||||
pm.ginEngine.Use(func(c *gin.Context) {
|
pm.ginEngine.Use(func(c *gin.Context) {
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
|
||||||
c.AbortWithStatus(204)
|
// allow whatever the client requested by default
|
||||||
|
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||||
|
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
||||||
|
c.Header("Access-Control-Allow-Headers", sanitized)
|
||||||
|
} else {
|
||||||
|
c.Header(
|
||||||
|
"Access-Control-Allow-Headers",
|
||||||
|
"Content-Type, Authorization, Accept, X-Requested-With",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
c.Header("Access-Control-Max-Age", "86400")
|
||||||
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
mm := MetricsMiddleware(pm)
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
|
||||||
// Support legacy /v1/completions api, see issue #12
|
|
||||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
|
||||||
|
|
||||||
// Support embeddings
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
// Support legacy /v1/completions api, see issue #12
|
||||||
|
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// Support embeddings and reranking
|
||||||
|
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /reranking endpoint + aliases
|
||||||
|
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /infill endpoint for code infilling
|
||||||
|
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /completion endpoint
|
||||||
|
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
@@ -103,198 +215,181 @@ func New(config *Config) *ProxyManager {
|
|||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
/**
|
||||||
|
* User Interface Endpoints
|
||||||
|
*/
|
||||||
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
|
c.Redirect(http.StatusFound, "/ui")
|
||||||
|
})
|
||||||
|
|
||||||
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
|
})
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||||
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
c.String(http.StatusOK, "OK")
|
||||||
// Set the Content-Type header to text/html
|
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
|
||||||
htmlData, err := getHTMLFile("index.html")
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = c.Writer.Write(htmlData)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||||
c.Data(http.StatusOK, "image/x-icon", data)
|
c.Data(http.StatusOK, "image/x-icon", data)
|
||||||
} else {
|
} else {
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
reactFS, err := GetReactFS()
|
||||||
|
if err != nil {
|
||||||
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// serve files that exist under /ui/*
|
||||||
|
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||||
|
|
||||||
|
// server SPA for UI under /ui/*
|
||||||
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||||
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := reactFS.Open("index.html")
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: proxymanager_api.go
|
||||||
|
// add API handler functions
|
||||||
|
addApiHandlers(pm)
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
|
|
||||||
return pm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) Run(addr ...string) error {
|
// ServeHTTP implements http.Handler interface
|
||||||
return pm.ginEngine.Run(addr...)
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
|
||||||
pm.ginEngine.ServeHTTP(w, r)
|
pm.ginEngine.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) StopProcesses() {
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
||||||
|
// This is the public method safe for concurrent calls.
|
||||||
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
||||||
|
// a complete shutdown, allowing for process replacement without full termination.
|
||||||
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
pm.stopProcesses()
|
|
||||||
}
|
|
||||||
|
|
||||||
// for internal usage
|
|
||||||
func (pm *ProxyManager) stopProcesses() {
|
|
||||||
if len(pm.currentProcesses) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop Processes in parallel
|
// stop Processes in parallel
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
for _, processGroup := range pm.processGroups {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Stop()
|
processGroup.StopProcesses(strategy)
|
||||||
}(process)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
pm.currentProcesses = make(map[string]*Process)
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called to shutdown all upstream processes
|
// Shutdown stops all processes managed by this ProxyManager
|
||||||
// when llama-swap is shutting down.
|
|
||||||
func (pm *ProxyManager) Shutdown() {
|
func (pm *ProxyManager) Shutdown() {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// shutdown process in parallel
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, process := range pm.currentProcesses {
|
// Send shutdown signal to all process in groups
|
||||||
|
for _, processGroup := range pm.processGroups {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(process *Process) {
|
go func(processGroup *ProcessGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
process.Shutdown()
|
processGroup.Shutdown()
|
||||||
}(process)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
pm.shutdownCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||||
|
// de-alias the real model name and get a real one
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if processGroup.exclusive {
|
||||||
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
||||||
|
for groupId, otherGroup := range pm.processGroups {
|
||||||
|
if groupId != processGroup.id && !otherGroup.persistent {
|
||||||
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return processGroup, realModelName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := make([]gin.H, 0, len(pm.config.Models))
|
||||||
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
if modelConfig.Unlisted {
|
if modelConfig.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, map[string]interface{}{
|
record := gin.H{
|
||||||
"id": id,
|
"id": id,
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": time.Now().Unix(),
|
"created": createdTime,
|
||||||
"owned_by": "llama-swap",
|
"owned_by": "llama-swap",
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||||
|
record["name"] = name
|
||||||
|
}
|
||||||
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||||
|
record["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
data = append(data, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the Content-Type header to application/json
|
// Sort by the "id" key
|
||||||
c.Header("Content-Type", "application/json")
|
sort.Slice(data, func(i, j int) bool {
|
||||||
|
si, _ := data[i]["id"].(string)
|
||||||
|
sj, _ := data[j]["id"].(string)
|
||||||
|
return si < sj
|
||||||
|
})
|
||||||
|
|
||||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
// Set CORS headers if origin exists
|
||||||
|
if origin := c.GetHeader("Origin"); origin != "" {
|
||||||
c.Header("Access-Control-Allow-Origin", origin)
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Use gin's JSON method which handles content-type and encoding
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
"object": "list",
|
||||||
return
|
"data": data,
|
||||||
}
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|
||||||
pm.Lock()
|
|
||||||
defer pm.Unlock()
|
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
|
||||||
|
|
||||||
if profileName != "" {
|
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
|
||||||
return nil, fmt.Errorf("model group not found %s", profileName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// de-alias the real model name and get a real one
|
|
||||||
realModelName, found := pm.config.RealModelName(modelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if model is part of the profile
|
|
||||||
if profileName != "" {
|
|
||||||
found := false
|
|
||||||
for _, item := range pm.config.Profiles[profileName] {
|
|
||||||
if item == realModelName {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// exit early when already running, otherwise stop everything and swap
|
|
||||||
requestedProcessKey := ProcessKeyName(profileName, realModelName)
|
|
||||||
|
|
||||||
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
|
||||||
return process, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop all running models
|
|
||||||
pm.stopProcesses()
|
|
||||||
|
|
||||||
if profileName == "" {
|
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
|
||||||
pm.currentProcesses[processKey] = process
|
|
||||||
} else {
|
|
||||||
for _, modelName := range pm.config.Profiles[profileName] {
|
|
||||||
if realModelName, found := pm.config.RealModelName(modelName); found {
|
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
|
|
||||||
}
|
|
||||||
|
|
||||||
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
|
||||||
processKey := ProcessKeyName(profileName, modelID)
|
|
||||||
pm.currentProcesses[processKey] = process
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestedProcessKey should exist due to swap
|
|
||||||
return pm.currentProcesses[requestedProcessKey], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
@@ -305,38 +400,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
if err != nil {
|
||||||
} else {
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
// rewrite the path
|
return
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
// rewrite the path
|
||||||
var html strings.Builder
|
c.Request.URL.Path = c.Param("upstreamPath")
|
||||||
|
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
|
||||||
|
|
||||||
// Extract keys and sort them
|
|
||||||
var modelIDs []string
|
|
||||||
for modelID, modelConfig := range pm.config.Models {
|
|
||||||
if modelConfig.Unlisted {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
modelIDs = append(modelIDs, modelID)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIDs)
|
|
||||||
|
|
||||||
// Iterate over sorted keys
|
|
||||||
for _, modelID := range modelIDs {
|
|
||||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
|
||||||
}
|
|
||||||
html.WriteString("</ul></body></html>")
|
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
c.String(http.StatusOK, html.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
@@ -349,50 +421,61 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
process, err := pm.swapModel(requestedModel)
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue #69 allow custom model names to be sent to upstream
|
// issue #69 allow custom model names to be sent to upstream
|
||||||
if process.config.UseModelName != "" {
|
useModelName := pm.config.Models[realModelName].UseModelName
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
if useModelName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue #174 strip parameters from the JSON body
|
||||||
|
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||||
|
if err != nil { // just log it and continue
|
||||||
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||||
} else {
|
} else {
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
for _, param := range stripParams {
|
||||||
if profileName != "" {
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
c.Request.Header.Del("transfer-encoding")
|
c.Request.Header.Del("transfer-encoding")
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
process.ProxyRequest(c.Writer, c.Request)
|
|
||||||
|
|
||||||
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Parse multipart form
|
// Parse multipart form
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
@@ -406,15 +489,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Swap to the requested model
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
process, err := pm.swapModel(requestedModel)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get profile name and model name from the requested model
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
profileName, modelName := splitRequestedModel(requestedModel)
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
// Copy all form values
|
// Copy all form values
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
@@ -422,10 +506,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
fieldValue := value
|
fieldValue := value
|
||||||
// If this is the model field and we have a profile, use just the model name
|
// If this is the model field and we have a profile, use just the model name
|
||||||
if key == "model" {
|
if key == "model" {
|
||||||
if process.config.UseModelName != "" {
|
// # issue #69 allow custom model names to be sent to upstream
|
||||||
fieldValue = process.config.UseModelName
|
useModelName := pm.config.Models[realModelName].UseModelName
|
||||||
} else if profileName != "" {
|
|
||||||
fieldValue = modelName
|
if useModelName != "" {
|
||||||
|
fieldValue = useModelName
|
||||||
|
} else {
|
||||||
|
fieldValue = requestedModel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
field, err := multipartWriter.CreateFormField(key)
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
@@ -486,8 +573,16 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.Header = c.Request.Header.Clone()
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// set the content length of the body
|
||||||
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||||
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
process.ProxyRequest(c.Writer, modifiedReq)
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
@@ -501,7 +596,7 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||||
pm.StopProcesses()
|
pm.StopProcesses(StopImmediately)
|
||||||
c.String(http.StatusOK, "OK")
|
c.String(http.StatusOK, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -509,14 +604,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.Header("Content-Type", "application/json")
|
context.Header("Content-Type", "application/json")
|
||||||
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
||||||
|
|
||||||
for _, process := range pm.currentProcesses {
|
for _, processGroup := range pm.processGroups {
|
||||||
|
for _, process := range processGroup.processes {
|
||||||
// Append the process ID and State (multiple entries if profiles are being used).
|
if process.CurrentState() == StateReady {
|
||||||
runningProcesses = append(runningProcesses, gin.H{
|
runningProcesses = append(runningProcesses, gin.H{
|
||||||
"model": process.ID,
|
"model": process.ID,
|
||||||
"state": process.state,
|
"state": process.state,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put the results under the `running` key.
|
// Put the results under the `running` key.
|
||||||
@@ -527,15 +623,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|||||||
context.JSON(http.StatusOK, response) // Always return 200 OK
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProcessKeyName(groupName, modelName string) string {
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
for _, group := range pm.processGroups {
|
||||||
}
|
if group.HasMember(modelName) {
|
||||||
|
return group
|
||||||
func splitRequestedModel(requestedModel string) (string, string) {
|
}
|
||||||
profileName, modelName := "", requestedModel
|
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
}
|
||||||
return profileName, modelName
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,202 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Unlisted bool `json:"unlisted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
|
// Add API endpoints for React to consume
|
||||||
|
apiGroup := pm.ginEngine.Group("/api")
|
||||||
|
{
|
||||||
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
|
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||||
|
pm.StopProcesses(StopImmediately)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) getModelStatus() []Model {
|
||||||
|
// Extract keys and sort them
|
||||||
|
models := []Model{}
|
||||||
|
|
||||||
|
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||||
|
for modelID := range pm.config.Models {
|
||||||
|
modelIDs = append(modelIDs, modelID)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIDs)
|
||||||
|
|
||||||
|
// Iterate over sorted keys
|
||||||
|
for _, modelID := range modelIDs {
|
||||||
|
// Get process state
|
||||||
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
|
state := "unknown"
|
||||||
|
if processGroup != nil {
|
||||||
|
process := processGroup.processes[modelID]
|
||||||
|
if process != nil {
|
||||||
|
var stateStr string
|
||||||
|
switch process.CurrentState() {
|
||||||
|
case StateReady:
|
||||||
|
stateStr = "ready"
|
||||||
|
case StateStarting:
|
||||||
|
stateStr = "starting"
|
||||||
|
case StateStopping:
|
||||||
|
stateStr = "stopping"
|
||||||
|
case StateShutdown:
|
||||||
|
stateStr = "shutdown"
|
||||||
|
case StateStopped:
|
||||||
|
stateStr = "stopped"
|
||||||
|
default:
|
||||||
|
stateStr = "unknown"
|
||||||
|
}
|
||||||
|
state = stateStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models = append(models, Model{
|
||||||
|
Id: modelID,
|
||||||
|
Name: pm.config.Models[modelID].Name,
|
||||||
|
Description: pm.config.Models[modelID].Description,
|
||||||
|
State: state,
|
||||||
|
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
msgTypeModelStatus messageType = "modelStatus"
|
||||||
|
msgTypeLogData messageType = "logData"
|
||||||
|
msgTypeMetrics messageType = "metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageEnvelope struct {
|
||||||
|
Type messageType `json:"type"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// sends a stream of different message types that happen on the server
|
||||||
|
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
|
sendBuffer := make(chan messageEnvelope, 25)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
sendModels := func() {
|
||||||
|
data, err := json.Marshal(pm.getModelStatus())
|
||||||
|
if err == nil {
|
||||||
|
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||||
|
select {
|
||||||
|
case sendBuffer <- msg:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendLogData := func(source string, data []byte) {
|
||||||
|
data, err := json.Marshal(gin.H{
|
||||||
|
"source": source,
|
||||||
|
"data": string(data),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMetrics := func(metrics []TokenMetrics) {
|
||||||
|
jsonData, err := json.Marshal(metrics)
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send updated models list
|
||||||
|
*/
|
||||||
|
defer event.On(func(e ProcessStateChangeEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
defer event.On(func(e ConfigFileChangedEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Log data
|
||||||
|
*/
|
||||||
|
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("proxy", data)
|
||||||
|
})()
|
||||||
|
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("upstream", data)
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Metrics data
|
||||||
|
*/
|
||||||
|
defer event.On(func(e TokenMetricsEvent) {
|
||||||
|
sendMetrics([]TokenMetrics{e.Metrics})
|
||||||
|
})()
|
||||||
|
|
||||||
|
// send initial batch of data
|
||||||
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
|
sendModels()
|
||||||
|
sendMetrics(pm.metricsMonitor.GetMetrics())
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case <-pm.shutdownCtx.Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case msg := <-sendBuffer:
|
||||||
|
c.SSEvent("message", msg)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||||
|
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", jsonData)
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,26 +10,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
|
|
||||||
accept := c.GetHeader("Accept")
|
accept := c.GetHeader("Accept")
|
||||||
if strings.Contains(accept, "text/html") {
|
if strings.Contains(accept, "text/html") {
|
||||||
// Set the Content-Type header to text/html
|
c.Redirect(http.StatusFound, "/ui/")
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
|
||||||
logsHTML, err := getHTMLFile("logs.html")
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = c.Writer.Write(logsHTML)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.logMonitor.GetHistory()
|
history := pm.muxLogger.GetHistory()
|
||||||
_, err := c.Writer.Write(history)
|
_, err := c.Writer.Write(history)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, err)
|
c.AbortWithError(http.StatusInternalServerError, err)
|
||||||
@@ -42,10 +29,13 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
ch := pm.logMonitor.Subscribe()
|
logMonitorId := c.Param("logMonitorID")
|
||||||
defer pm.logMonitor.Unsubscribe(ch)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
@@ -56,58 +46,53 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
// Send history first if not skipped
|
// Send history first if not skipped
|
||||||
|
|
||||||
if !skipHistory {
|
if !skipHistory {
|
||||||
history := pm.logMonitor.GetHistory()
|
history := logger.GetHistory()
|
||||||
if len(history) != 0 {
|
if len(history) != 0 {
|
||||||
c.Writer.Write(history)
|
c.Writer.Write(history)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream new logs
|
sendChan := make(chan []byte, 10)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
defer logger.OnLogData(func(data []byte) {
|
||||||
|
select {
|
||||||
|
case sendChan <- data:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case <-c.Request.Context().Done():
|
||||||
_, err := c.Writer.Write(msg)
|
cancel()
|
||||||
if err != nil {
|
return
|
||||||
// just break the loop if we can't write for some reason
|
case <-pm.shutdownCtx.Done():
|
||||||
return
|
cancel()
|
||||||
}
|
return
|
||||||
|
case data := <-sendChan:
|
||||||
|
c.Writer.Write(data)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
// getLogger searches for the appropriate logger based on the logMonitorId
|
||||||
c.Header("Content-Type", "text/event-stream")
|
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
|
||||||
c.Header("Cache-Control", "no-cache")
|
var logger *LogMonitor
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
|
||||||
|
|
||||||
ch := pm.logMonitor.Subscribe()
|
if logMonitorId == "" {
|
||||||
defer pm.logMonitor.Unsubscribe(ch)
|
// maintain the default
|
||||||
|
logger = pm.muxLogger
|
||||||
notify := c.Request.Context().Done()
|
} else if logMonitorId == "proxy" {
|
||||||
|
logger = pm.proxyLogger
|
||||||
// Send history first if not skipped
|
} else if logMonitorId == "upstream" {
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
logger = pm.upstreamLogger
|
||||||
if !skipHistory {
|
} else {
|
||||||
history := pm.logMonitor.GetHistory()
|
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'")
|
||||||
if len(history) != 0 {
|
|
||||||
c.SSEvent("message", string(history))
|
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream new logs
|
return logger, nil
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
c.SSEvent("message", string(msg))
|
|
||||||
c.Writer.Flush()
|
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isTokenChar(r rune) bool {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
case strings.ContainsRune("!#$%&'*+-.^_`|~", r):
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
|
||||||
|
parts := strings.Split(headerValues, ",")
|
||||||
|
valid := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
for _, p := range parts {
|
||||||
|
v := strings.TrimSpace(p)
|
||||||
|
if v == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
validPart := true
|
||||||
|
for _, c := range v {
|
||||||
|
if !isTokenChar(c) {
|
||||||
|
validPart = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if validPart {
|
||||||
|
valid = append(valid, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(valid, ", ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace only",
|
||||||
|
input: " ",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single valid value",
|
||||||
|
input: "content-type",
|
||||||
|
expected: "content-type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid values",
|
||||||
|
input: "content-type, authorization, x-requested-with",
|
||||||
|
expected: "content-type, authorization, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with extra spaces",
|
||||||
|
input: " content-type , authorization ",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with tabs",
|
||||||
|
input: "content-type,\tauthorization",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "values with invalid characters",
|
||||||
|
input: "content-type, auth\n, x-requested-with\r",
|
||||||
|
expected: "content-type, auth, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty values in list",
|
||||||
|
input: "content-type,,authorization",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leading and trailing commas",
|
||||||
|
input: ",content-type,authorization,",
|
||||||
|
expected: "content-type, authorization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed valid and invalid values",
|
||||||
|
input: "content-type, \x00invalid, x-requested-with",
|
||||||
|
expected: "content-type, x-requested-with",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case values",
|
||||||
|
input: "Content-Type, my-Valid-Header, Another-hEader",
|
||||||
|
expected: "Content-Type, my-Valid-Header, Another-hEader",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SanitizeAccessControlRequestHeaderValues(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q",
|
||||||
|
tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed ui_dist
|
||||||
|
var reactStaticFS embed.FS
|
||||||
|
|
||||||
|
// GetReactFS returns the embedded React filesystem
|
||||||
|
func GetReactFS() (http.FileSystem, error) {
|
||||||
|
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return http.FS(subFS), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReactIndexHTML returns the main index.html for the React app
|
||||||
|
func GetReactIndexHTML() ([]byte, error) {
|
||||||
|
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||||
|
}
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script installs llama-swap on Linux.
|
||||||
|
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v "$1" >/dev/null; }
|
||||||
|
require() {
|
||||||
|
_MISSING=''
|
||||||
|
for TOOL in "$@"; do
|
||||||
|
if ! available "$TOOL"; then
|
||||||
|
_MISSING="$_MISSING $TOOL"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "$_MISSING"
|
||||||
|
}
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
NEEDS=$(require tee tar python3 mktemp)
|
||||||
|
if [ -n "$NEEDS" ]; then
|
||||||
|
status "ERROR: The following tools are required but missing:"
|
||||||
|
for NEED in $NEEDS; do
|
||||||
|
echo " - $NEED"
|
||||||
|
done
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) error "Unsupported architecture: $ARCH" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
IS_WSL2=false
|
||||||
|
|
||||||
|
KERN=$(uname -r)
|
||||||
|
case "$KERN" in
|
||||||
|
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||||
|
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||||
|
*) ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
download_binary() {
|
||||||
|
ASSET_NAME="linux_$ARCH"
|
||||||
|
|
||||||
|
TMPDIR=$(mktemp -d)
|
||||||
|
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||||
|
PYTHON_SCRIPT=$(cat <<EOF
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
ASSET_NAME = "${ASSET_NAME}"
|
||||||
|
|
||||||
|
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||||
|
data = json.load(resp)
|
||||||
|
for asset in data.get("assets", []):
|
||||||
|
if ASSET_NAME in asset.get("name", ""):
|
||||||
|
url = asset["browser_download_url"]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("Downloading:", url, file=sys.stderr)
|
||||||
|
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||||
|
urllib.request.urlretrieve(url, output_path)
|
||||||
|
print(output_path)
|
||||||
|
EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||||
|
if [ ! -f "$TARFILE" ]; then
|
||||||
|
error "Failed to download binary."
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Extracting to /usr/local/bin"
|
||||||
|
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||||
|
}
|
||||||
|
download_binary
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
if ! id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Creating llama-swap user..."
|
||||||
|
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||||
|
fi
|
||||||
|
if getent group render >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to render group..."
|
||||||
|
$SUDO usermod -a -G render llama-swap
|
||||||
|
fi
|
||||||
|
if getent group video >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to video group..."
|
||||||
|
$SUDO usermod -a -G video llama-swap
|
||||||
|
fi
|
||||||
|
if getent group docker >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to docker group..."
|
||||||
|
$SUDO usermod -a -G docker llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Adding current user to llama-swap group..."
|
||||||
|
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||||
|
|
||||||
|
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
status "Creating default config.yaml..."
|
||||||
|
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||||
|
# default 15s likely to fail for default models due to downloading models
|
||||||
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name qwen2.5
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop qwen2.5
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name smollm2
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop smollm2
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Creating llama-swap systemd service..."
|
||||||
|
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||||
|
[Unit]
|
||||||
|
Description=llama-swap
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
User=llama-swap
|
||||||
|
Group=llama-swap
|
||||||
|
|
||||||
|
# set this to match your environment
|
||||||
|
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=3
|
||||||
|
StartLimitBurst=3
|
||||||
|
StartLimitInterval=30
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
EOF
|
||||||
|
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||||
|
case $SYSTEMCTL_RUNNING in
|
||||||
|
running|degraded)
|
||||||
|
status "Enabling and starting llama-swap service..."
|
||||||
|
$SUDO systemctl daemon-reload
|
||||||
|
$SUDO systemctl enable llama-swap
|
||||||
|
|
||||||
|
start_service() { $SUDO systemctl restart llama-swap; }
|
||||||
|
trap start_service EXIT
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
warning "systemd is not running"
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success() {
|
||||||
|
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||||
|
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||||
|
status 'Install complete.'
|
||||||
|
}
|
||||||
|
|
||||||
|
# WSL2 only supports GPUs via nvidia passthrough
|
||||||
|
# so check for nvidia-smi to determine if GPU is available
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||||
|
status "Nvidia GPU detected."
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script uninstalls llama-swap on Linux.
|
||||||
|
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v $1 >/dev/null; }
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
status "Stopping llama-swap service..."
|
||||||
|
$SUDO systemctl stop llama-swap
|
||||||
|
|
||||||
|
status "Disabling llama-swap service..."
|
||||||
|
$SUDO systemctl disable llama-swap
|
||||||
|
}
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
if available llama-swap; then
|
||||||
|
status "Removing llama-swap binary..."
|
||||||
|
$SUDO rm $(which llama-swap)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
while true; do
|
||||||
|
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||||
|
read answer
|
||||||
|
case "$answer" in
|
||||||
|
[Yy]* )
|
||||||
|
$SUDO rm -r /usr/share/llama-swap
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
[Nn]* | "" )
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
* )
|
||||||
|
echo "Invalid input. Please enter y or n."
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap user..."
|
||||||
|
$SUDO userdel llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
if getent group llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap group..."
|
||||||
|
$SUDO groupdel llama-swap
|
||||||
|
fi
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
.vite
|
||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# React + TypeScript + Vite
|
||||||
|
|
||||||
|
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
||||||
|
|
||||||
|
Currently, two official plugins are available:
|
||||||
|
|
||||||
|
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
|
||||||
|
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
|
||||||
|
|
||||||
|
## Expanding the ESLint configuration
|
||||||
|
|
||||||
|
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
||||||
|
|
||||||
|
```js
|
||||||
|
export default tseslint.config({
|
||||||
|
extends: [
|
||||||
|
// Remove ...tseslint.configs.recommended and replace with this
|
||||||
|
...tseslint.configs.recommendedTypeChecked,
|
||||||
|
// Alternatively, use this for stricter rules
|
||||||
|
...tseslint.configs.strictTypeChecked,
|
||||||
|
// Optionally, add this for stylistic rules
|
||||||
|
...tseslint.configs.stylisticTypeChecked,
|
||||||
|
],
|
||||||
|
languageOptions: {
|
||||||
|
// other options...
|
||||||
|
parserOptions: {
|
||||||
|
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||||
|
tsconfigRootDir: import.meta.dirname,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
||||||
|
|
||||||
|
```js
|
||||||
|
// eslint.config.js
|
||||||
|
import reactX from 'eslint-plugin-react-x'
|
||||||
|
import reactDom from 'eslint-plugin-react-dom'
|
||||||
|
|
||||||
|
export default tseslint.config({
|
||||||
|
plugins: {
|
||||||
|
// Add the react-x and react-dom plugins
|
||||||
|
'react-x': reactX,
|
||||||
|
'react-dom': reactDom,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
// other rules...
|
||||||
|
// Enable its recommended typescript rules
|
||||||
|
...reactX.configs['recommended-typescript'].rules,
|
||||||
|
...reactDom.configs.recommended.rules,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import js from '@eslint/js'
|
||||||
|
import globals from 'globals'
|
||||||
|
import reactHooks from 'eslint-plugin-react-hooks'
|
||||||
|
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||||
|
import tseslint from 'typescript-eslint'
|
||||||
|
|
||||||
|
export default tseslint.config(
|
||||||
|
{ ignores: ['dist'] },
|
||||||
|
{
|
||||||
|
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
||||||
|
files: ['**/*.{ts,tsx}'],
|
||||||
|
languageOptions: {
|
||||||
|
ecmaVersion: 2020,
|
||||||
|
globals: globals.browser,
|
||||||
|
},
|
||||||
|
plugins: {
|
||||||
|
'react-hooks': reactHooks,
|
||||||
|
'react-refresh': reactRefresh,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
...reactHooks.configs.recommended.rules,
|
||||||
|
'react-refresh/only-export-components': [
|
||||||
|
'warn',
|
||||||
|
{ allowConstantExport: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||||
|
<link rel="shortcut icon" href="/favicon.ico" />
|
||||||
|
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||||
|
<link rel="manifest" href="/site.webmanifest" />
|
||||||
|
<title>llama-swap</title>
|
||||||
|
</head>
|
||||||
|
<body >
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
{
|
||||||
|
"name": "ui",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc -b && vite build --emptyOutDir",
|
||||||
|
"lint": "eslint .",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@tailwindcss/vite": "^4.1.8",
|
||||||
|
"@tanstack/react-query": "^5.80.6",
|
||||||
|
"react": "^19.1.0",
|
||||||
|
"react-dom": "^19.1.0",
|
||||||
|
"react-icons": "^5.5.0",
|
||||||
|
"react-resizable-panels": "^3.0.4",
|
||||||
|
"react-router-dom": "^7.6.2",
|
||||||
|
"tailwindcss": "^4.1.8"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@eslint/js": "^9.25.0",
|
||||||
|
"@types/react": "^19.1.2",
|
||||||
|
"@types/react-dom": "^19.1.2",
|
||||||
|
"@vitejs/plugin-react": "^4.4.1",
|
||||||
|
"eslint": "^9.25.0",
|
||||||
|
"eslint-plugin-react-hooks": "^5.2.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.4.19",
|
||||||
|
"globals": "^16.0.0",
|
||||||
|
"typescript": "~5.8.3",
|
||||||
|
"typescript-eslint": "^8.30.1",
|
||||||
|
"vite": "^6.3.5"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
After Width: | Height: | Size: 5.9 KiB |
|
After Width: | Height: | Size: 2.2 KiB |
|
After Width: | Height: | Size: 15 KiB |
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"name": "llama-swap",
|
||||||
|
"short_name": "llama-swap",
|
||||||
|
"icons": [
|
||||||
|
{
|
||||||
|
"src": "/web-app-manifest-192x192.png",
|
||||||
|
"sizes": "192x192",
|
||||||
|
"type": "image/png",
|
||||||
|
"purpose": "maskable"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"src": "/web-app-manifest-512x512.png",
|
||||||
|
"sizes": "512x512",
|
||||||
|
"type": "image/png",
|
||||||
|
"purpose": "maskable"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"theme_color": "#ffffff",
|
||||||
|
"background_color": "#ffffff",
|
||||||
|
"display": "standalone"
|
||||||
|
}
|
||||||
|
After Width: | Height: | Size: 6.5 KiB |
|
After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,6 @@
|
|||||||
|
#root {
|
||||||
|
max-width: 1280px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 2rem;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import { useEffect, useCallback } from "react";
|
||||||
|
import { BrowserRouter as Router, Routes, Route, Navigate, NavLink } from "react-router-dom";
|
||||||
|
import { useTheme } from "./contexts/ThemeProvider";
|
||||||
|
import { useAPI } from "./contexts/APIProvider";
|
||||||
|
import LogViewerPage from "./pages/LogViewer";
|
||||||
|
import ModelPage from "./pages/Models";
|
||||||
|
import ActivityPage from "./pages/Activity";
|
||||||
|
import ConnectionStatusIcon from "./components/ConnectionStatus";
|
||||||
|
import { RiSunFill, RiMoonFill } from "react-icons/ri";
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const { isNarrow, toggleTheme, isDarkMode, appTitle, setAppTitle, setConnectionState } = useTheme();
|
||||||
|
const handleTitleChange = useCallback(
|
||||||
|
(newTitle: string) => {
|
||||||
|
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
||||||
|
},
|
||||||
|
[setAppTitle]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { connectionStatus } = useAPI();
|
||||||
|
|
||||||
|
// Synchronize the window.title connections state with the actual connection state
|
||||||
|
useEffect(() => {
|
||||||
|
setConnectionState(connectionStatus);
|
||||||
|
}, [connectionStatus]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Router basename="/ui/">
|
||||||
|
<div className="flex flex-col h-screen">
|
||||||
|
<nav className="bg-surface border-b border-border p-2 h-[75px]">
|
||||||
|
<div className="flex items-center justify-between mx-auto px-4 h-full">
|
||||||
|
{!isNarrow && (
|
||||||
|
<h1
|
||||||
|
contentEditable
|
||||||
|
suppressContentEditableWarning
|
||||||
|
className="flex items-center p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded px-1"
|
||||||
|
onBlur={(e) => handleTitleChange(e.currentTarget.textContent || "(set title)")}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter") {
|
||||||
|
e.preventDefault();
|
||||||
|
handleTitleChange(e.currentTarget.textContent || "(set title)");
|
||||||
|
e.currentTarget.blur();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{appTitle}
|
||||||
|
</h1>
|
||||||
|
)}
|
||||||
|
<div className="flex items-center space-x-4">
|
||||||
|
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||||
|
Logs
|
||||||
|
</NavLink>
|
||||||
|
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||||
|
Models
|
||||||
|
</NavLink>
|
||||||
|
<NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
|
||||||
|
Activity
|
||||||
|
</NavLink>
|
||||||
|
<button className="" onClick={toggleTheme}>
|
||||||
|
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
|
||||||
|
</button>
|
||||||
|
<ConnectionStatusIcon />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
<main className="flex-1 overflow-auto p-4">
|
||||||
|
<Routes>
|
||||||
|
<Route path="/" element={<LogViewerPage />} />
|
||||||
|
<Route path="/models" element={<ModelPage />} />
|
||||||
|
<Route path="/activity" element={<ActivityPage />} />
|
||||||
|
<Route path="*" element={<Navigate to="/" replace />} />
|
||||||
|
</Routes>
|
||||||
|
</main>
|
||||||
|
</div>
|
||||||
|
</Router>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
|
After Width: | Height: | Size: 12 KiB |
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,26 @@
|
|||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
|
||||||
|
const ConnectionStatusIcon = () => {
|
||||||
|
const { connectionStatus } = useAPI();
|
||||||
|
|
||||||
|
const eventStatusColor = useMemo(() => {
|
||||||
|
switch (connectionStatus) {
|
||||||
|
case "connected":
|
||||||
|
return "bg-green-500";
|
||||||
|
case "connecting":
|
||||||
|
return "bg-yellow-500";
|
||||||
|
case "disconnected":
|
||||||
|
default:
|
||||||
|
return "bg-red-500";
|
||||||
|
}
|
||||||
|
}, [connectionStatus]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center" title={`event stream: ${connectionStatus}`}>
|
||||||
|
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ConnectionStatusIcon;
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
import { useRef, createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
|
||||||
|
import type { ConnectionState } from "../lib/types";
|
||||||
|
|
||||||
|
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||||
|
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||||
|
|
||||||
|
export interface Model {
|
||||||
|
id: string;
|
||||||
|
state: ModelStatus;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
unlisted: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface APIProviderType {
|
||||||
|
models: Model[];
|
||||||
|
listModels: () => Promise<Model[]>;
|
||||||
|
unloadAllModels: () => Promise<void>;
|
||||||
|
loadModel: (model: string) => Promise<void>;
|
||||||
|
enableAPIEvents: (enabled: boolean) => void;
|
||||||
|
proxyLogs: string;
|
||||||
|
upstreamLogs: string;
|
||||||
|
metrics: Metrics[];
|
||||||
|
connectionStatus: ConnectionState;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Metrics {
|
||||||
|
id: number;
|
||||||
|
timestamp: string;
|
||||||
|
model: string;
|
||||||
|
input_tokens: number;
|
||||||
|
output_tokens: number;
|
||||||
|
prompt_per_second: number;
|
||||||
|
tokens_per_second: number;
|
||||||
|
duration_ms: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface LogData {
|
||||||
|
source: "upstream" | "proxy";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
interface APIEventEnvelope {
|
||||||
|
type: "modelStatus" | "logData" | "metrics";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||||
|
type APIProviderProps = {
|
||||||
|
children: ReactNode;
|
||||||
|
autoStartAPIEvents?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
|
||||||
|
const [proxyLogs, setProxyLogs] = useState("");
|
||||||
|
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||||
|
const [metrics, setMetrics] = useState<Metrics[]>([]);
|
||||||
|
const [connectionStatus, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||||
|
const apiEventSource = useRef<EventSource | null>(null);
|
||||||
|
|
||||||
|
const [models, setModels] = useState<Model[]>([]);
|
||||||
|
|
||||||
|
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
|
||||||
|
setter((prev) => {
|
||||||
|
const updatedLog = prev + newData;
|
||||||
|
return updatedLog.length > LOG_LENGTH_LIMIT ? updatedLog.slice(-LOG_LENGTH_LIMIT) : updatedLog;
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||||
|
if (!enabled) {
|
||||||
|
apiEventSource.current?.close();
|
||||||
|
apiEventSource.current = null;
|
||||||
|
setMetrics([]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let retryCount = 0;
|
||||||
|
const initialDelay = 1000; // 1 second
|
||||||
|
|
||||||
|
const connect = () => {
|
||||||
|
apiEventSource.current = null;
|
||||||
|
const eventSource = new EventSource("/api/events");
|
||||||
|
setConnectionState("connecting");
|
||||||
|
|
||||||
|
eventSource.onopen = () => {
|
||||||
|
// clear everything out on connect to keep things in sync
|
||||||
|
setProxyLogs("");
|
||||||
|
setUpstreamLogs("");
|
||||||
|
setMetrics([]); // clear metrics on reconnect
|
||||||
|
setModels([]); // clear models on reconnect
|
||||||
|
apiEventSource.current = eventSource;
|
||||||
|
retryCount = 0;
|
||||||
|
setConnectionState("connected");
|
||||||
|
};
|
||||||
|
|
||||||
|
eventSource.onmessage = (e: MessageEvent) => {
|
||||||
|
try {
|
||||||
|
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||||
|
switch (message.type) {
|
||||||
|
case "modelStatus":
|
||||||
|
{
|
||||||
|
const models = JSON.parse(message.data) as Model[];
|
||||||
|
|
||||||
|
// sort models by name and id
|
||||||
|
models.sort((a, b) => {
|
||||||
|
return (a.name + a.id).localeCompare(b.name + b.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
setModels(models);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "logData":
|
||||||
|
const logData = JSON.parse(message.data) as LogData;
|
||||||
|
switch (logData.source) {
|
||||||
|
case "proxy":
|
||||||
|
appendLog(logData.data, setProxyLogs);
|
||||||
|
break;
|
||||||
|
case "upstream":
|
||||||
|
appendLog(logData.data, setUpstreamLogs);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "metrics":
|
||||||
|
{
|
||||||
|
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||||
|
setMetrics((prevMetrics) => {
|
||||||
|
return [...newMetrics, ...prevMetrics];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error(e.data, err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
eventSource.onerror = () => {
|
||||||
|
eventSource.close();
|
||||||
|
retryCount++;
|
||||||
|
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||||
|
setConnectionState("disconnected");
|
||||||
|
setTimeout(connect, delay);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
connect();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (autoStartAPIEvents) {
|
||||||
|
enableAPIEvents(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
enableAPIEvents(false);
|
||||||
|
};
|
||||||
|
}, [enableAPIEvents, autoStartAPIEvents]);
|
||||||
|
|
||||||
|
const listModels = useCallback(async (): Promise<Model[]> => {
|
||||||
|
try {
|
||||||
|
const response = await fetch("/api/models/");
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
|
}
|
||||||
|
const data = await response.json();
|
||||||
|
return data || [];
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to fetch models:", error);
|
||||||
|
return []; // Return empty array as fallback
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const unloadAllModels = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/models/unload/`, {
|
||||||
|
method: "POST",
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to unload models: ${response.status}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to unload models:", error);
|
||||||
|
throw error; // Re-throw to let calling code handle it
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const loadModel = useCallback(async (model: string) => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/upstream/${model}/`, {
|
||||||
|
method: "GET",
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to load model: ${response.status}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to load model:", error);
|
||||||
|
throw error; // Re-throw to let calling code handle it
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const value = useMemo(
|
||||||
|
() => ({
|
||||||
|
models,
|
||||||
|
listModels,
|
||||||
|
unloadAllModels,
|
||||||
|
loadModel,
|
||||||
|
enableAPIEvents,
|
||||||
|
proxyLogs,
|
||||||
|
upstreamLogs,
|
||||||
|
metrics,
|
||||||
|
connectionStatus,
|
||||||
|
}),
|
||||||
|
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
|
||||||
|
);
|
||||||
|
|
||||||
|
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useAPI() {
|
||||||
|
const context = useContext(APIContext);
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useAPI must be used within an APIProvider");
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import type { ConnectionState } from "../lib/types";
|
||||||
|
|
||||||
|
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||||
|
type ThemeContextType = {
|
||||||
|
isDarkMode: boolean;
|
||||||
|
screenWidth: ScreenWidth;
|
||||||
|
isNarrow: boolean;
|
||||||
|
toggleTheme: () => void;
|
||||||
|
|
||||||
|
// for managing the window title and connection state information
|
||||||
|
appTitle: string;
|
||||||
|
setAppTitle: (title: string) => void;
|
||||||
|
setConnectionState: (state: ConnectionState) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
|
||||||
|
|
||||||
|
type ThemeProviderProps = {
|
||||||
|
children: ReactNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||||
|
const [appTitle, setAppTitle] = usePersistentState("app-title", "llama-swap");
|
||||||
|
const [connectionState, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the document.title with informative information
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
const connectionIcon = connectionState === "connecting" ? "🟡" : connectionState === "connected" ? "🟢" : "🔴";
|
||||||
|
document.title = connectionIcon + " " + appTitle; // Set initial title
|
||||||
|
}, [appTitle, connectionState]);
|
||||||
|
|
||||||
|
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||||
|
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
|
||||||
|
|
||||||
|
// matches tailwind classes
|
||||||
|
// https://tailwindcss.com/docs/responsive-design
|
||||||
|
useEffect(() => {
|
||||||
|
const checkInnerWidth = () => {
|
||||||
|
const innerWidth = window.innerWidth;
|
||||||
|
if (innerWidth < 640) {
|
||||||
|
setScreenWidth("xs");
|
||||||
|
} else if (innerWidth < 768) {
|
||||||
|
setScreenWidth("sm");
|
||||||
|
} else if (innerWidth < 1024) {
|
||||||
|
setScreenWidth("md");
|
||||||
|
} else if (innerWidth < 1280) {
|
||||||
|
setScreenWidth("lg");
|
||||||
|
} else if (innerWidth < 1536) {
|
||||||
|
setScreenWidth("xl");
|
||||||
|
} else {
|
||||||
|
setScreenWidth("2xl");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
checkInnerWidth();
|
||||||
|
window.addEventListener("resize", checkInnerWidth);
|
||||||
|
|
||||||
|
return () => window.removeEventListener("resize", checkInnerWidth);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||||
|
}, [isDarkMode]);
|
||||||
|
|
||||||
|
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||||
|
const isNarrow = useMemo(() => {
|
||||||
|
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
|
||||||
|
}, [screenWidth]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ThemeContext.Provider
|
||||||
|
value={{
|
||||||
|
isDarkMode,
|
||||||
|
toggleTheme,
|
||||||
|
screenWidth,
|
||||||
|
isNarrow,
|
||||||
|
appTitle,
|
||||||
|
setAppTitle,
|
||||||
|
setConnectionState,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</ThemeContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useTheme(): ThemeContextType {
|
||||||
|
const context = useContext(ThemeContext);
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useTheme must be used within a ThemeProvider");
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import { useState, useEffect, useCallback } from "react";
|
||||||
|
|
||||||
|
export function usePersistentState<T>(key: string, initialValue: T): [T, (value: T | ((prevState: T) => T)) => void] {
|
||||||
|
const [state, setState] = useState<T>(() => {
|
||||||
|
if (typeof window === "undefined") return initialValue;
|
||||||
|
try {
|
||||||
|
const saved = localStorage.getItem(key);
|
||||||
|
return saved !== null ? JSON.parse(saved) : initialValue;
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error parsing stored value for ${key}`, e);
|
||||||
|
return initialValue;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const setPersistentState = useCallback(
|
||||||
|
(value: T | ((prevState: T) => T)) => {
|
||||||
|
setState((prev) => {
|
||||||
|
const nextValue = typeof value === "function" ? (value as (prevState: T) => T)(prev) : value;
|
||||||
|
try {
|
||||||
|
localStorage.setItem(key, JSON.stringify(nextValue));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error saving value for ${key}`, e);
|
||||||
|
}
|
||||||
|
return nextValue;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[key]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(key, JSON.stringify(state));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error saving value for ${key}`, e);
|
||||||
|
}
|
||||||
|
}, [key, state]);
|
||||||
|
|
||||||
|
return [state, setPersistentState];
|
||||||
|
}
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
@import "tailwindcss";
|
||||||
|
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
|
||||||
|
|
||||||
|
@theme {
|
||||||
|
--color-background: rgba(252, 252, 249, 1);
|
||||||
|
--color-surface: rgba(255, 255, 253, 1);
|
||||||
|
|
||||||
|
/* text colors */
|
||||||
|
--color-txtmain: rgba(19, 52, 59, 1);
|
||||||
|
--color-txtsecondary: rgba(98, 108, 113, 1);
|
||||||
|
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||||
|
|
||||||
|
--color-primary: rgba(50, 184, 198, 1);
|
||||||
|
|
||||||
|
--color-primary-hover: rgba(29, 116, 128, 1);
|
||||||
|
--color-primary-active: rgba(26, 104, 115, 1);
|
||||||
|
--color-secondary: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-secondary-hover: rgba(94, 82, 64, 0.2);
|
||||||
|
--color-secondary-active: rgba(94, 82, 64, 0.25);
|
||||||
|
--color-border: rgba(94, 82, 64, 0.3);
|
||||||
|
--color-btn-primary-text: rgba(252, 252, 249, 1);
|
||||||
|
--color-card-border: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-card-border-inner: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-error: rgba(192, 21, 47, 1);
|
||||||
|
--color-success: rgba(33, 128, 141, 1);
|
||||||
|
--color-warning: rgb(244, 155, 0);
|
||||||
|
--color-info: rgba(98, 108, 113, 1);
|
||||||
|
--color-focus-ring: rgba(33, 128, 141, 0.4);
|
||||||
|
--color-select-caret: rgba(19, 52, 59, 0.8);
|
||||||
|
--color-btn-border: rgba(94, 82, 64, 0.7);
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer theme {
|
||||||
|
/* over ride theme for dark mode */
|
||||||
|
[data-theme="dark"] {
|
||||||
|
--color-background: rgba(31, 33, 33, 1);
|
||||||
|
--color-surface: rgba(38, 40, 40, 1);
|
||||||
|
/* text colors */
|
||||||
|
--color-txtmain: rgba(245, 245, 245, 1);
|
||||||
|
--color-txtsecondary: rgba(167, 169, 169, 0.7);
|
||||||
|
|
||||||
|
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||||
|
|
||||||
|
--color-primary: rgba(33, 128, 141, 1);
|
||||||
|
--color-primary-hover: rgba(45, 166, 178, 1);
|
||||||
|
--color-primary-active: rgba(41, 150, 161, 1);
|
||||||
|
--color-secondary: rgba(119, 124, 124, 0.15);
|
||||||
|
--color-secondary-hover: rgba(119, 124, 124, 0.25);
|
||||||
|
--color-secondary-active: rgba(119, 124, 124, 0.3);
|
||||||
|
--color-border: rgba(119, 124, 124, 0.3);
|
||||||
|
--color-error: rgba(255, 84, 89, 1);
|
||||||
|
--color-success: rgba(50, 184, 198, 1);
|
||||||
|
--color-warning: rgb(244, 155, 0);
|
||||||
|
--color-info: rgba(167, 169, 169, 1);
|
||||||
|
--color-focus-ring: rgba(50, 184, 198, 0.4);
|
||||||
|
--color-btn-primary-text: rgba(19, 52, 59, 1);
|
||||||
|
--color-card-border: rgba(119, 124, 124, 0.2);
|
||||||
|
--color-card-border-inner: rgba(119, 124, 124, 0.15);
|
||||||
|
--shadow-inset-sm: inset 0 1px 0 rgba(255, 255, 255, 0.1), inset 0 -1px 0 rgba(0, 0, 0, 0.15);
|
||||||
|
--button-border-secondary: rgba(119, 124, 124, 0.2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer base {
|
||||||
|
body {
|
||||||
|
/* example of how colors using theme colors*/
|
||||||
|
@apply bg-background text-txtmain;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
@apply text-4xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
@apply text-3xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h3 {
|
||||||
|
@apply text-2xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h4 {
|
||||||
|
@apply text-xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h5 {
|
||||||
|
@apply text-lg text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h6 {
|
||||||
|
@apply text-base text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* define CSS classes here for specific types of components */
|
||||||
|
@layer components {
|
||||||
|
.container {
|
||||||
|
@apply px-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Navigation Header */
|
||||||
|
|
||||||
|
.navlink {
|
||||||
|
@apply text-txtsecondary hover:bg-secondary hover:text-txtmain rounded-lg p-2;
|
||||||
|
}
|
||||||
|
.navlink.active {
|
||||||
|
@apply bg-primary text-navlink-active;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Card component */
|
||||||
|
.card {
|
||||||
|
@apply bg-surface rounded-lg border border-card-border shadow-sm overflow-hidden p-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card:hover {
|
||||||
|
@apply shadow-md;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card__body {
|
||||||
|
@apply p-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card__header,
|
||||||
|
.card__footer {
|
||||||
|
@apply p-4 border-b border-card-border-inner;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Status Badges */
|
||||||
|
.status {
|
||||||
|
@apply inline-block px-2 py-1 text-xs font-medium rounded-full;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--ready {
|
||||||
|
@apply bg-success/10 text-success;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--starting,
|
||||||
|
.status--stopping {
|
||||||
|
@apply bg-warning/10 text-warning;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--stopped {
|
||||||
|
@apply bg-error/10 text-error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Buttons */
|
||||||
|
.btn {
|
||||||
|
@apply bg-surface p-2 px-4 text-sm rounded-full border border-2 transition-colors duration-200 border-btn-border;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn:hover {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn--sm {
|
||||||
|
@apply px-2 py-0.5 text-xs;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn:disabled {
|
||||||
|
@apply opacity-50 cursor-not-allowed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer utilities {
|
||||||
|
.ml-2 {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.my-8 {
|
||||||
|
margin-top: 2rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
export type ConnectionState = "connected" | "connecting" | "disconnected";
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { StrictMode } from "react";
|
||||||
|
import { createRoot } from "react-dom/client";
|
||||||
|
import "./index.css";
|
||||||
|
import App from "./App.tsx";
|
||||||
|
import { ThemeProvider } from "./contexts/ThemeProvider";
|
||||||
|
import { APIProvider } from "./contexts/APIProvider";
|
||||||
|
|
||||||
|
createRoot(document.getElementById("root")!).render(
|
||||||
|
<StrictMode>
|
||||||
|
<ThemeProvider>
|
||||||
|
<APIProvider>
|
||||||
|
<App />
|
||||||
|
</APIProvider>
|
||||||
|
</ThemeProvider>
|
||||||
|
</StrictMode>
|
||||||
|
);
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
import { useMemo } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
|
||||||
|
const formatTimestamp = (timestamp: string): string => {
|
||||||
|
return new Date(timestamp).toLocaleString();
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatSpeed = (speed: number): string => {
|
||||||
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatDuration = (ms: number): string => {
|
||||||
|
return (ms / 1000).toFixed(2) + "s";
|
||||||
|
};
|
||||||
|
|
||||||
|
const ActivityPage = () => {
|
||||||
|
const { metrics } = useAPI();
|
||||||
|
const sortedMetrics = useMemo(() => {
|
||||||
|
return [...metrics].sort((a, b) => b.id - a.id);
|
||||||
|
}, [metrics]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="p-6">
|
||||||
|
<h1 className="text-2xl font-bold mb-4">Activity</h1>
|
||||||
|
|
||||||
|
{metrics.length === 0 ? (
|
||||||
|
<div className="text-center py-8">
|
||||||
|
<p className="text-gray-600">No metrics data available</p>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="overflow-x-auto">
|
||||||
|
<table className="min-w-full divide-y">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th className="px-4 py-3 text-left text-xs font-medium uppercase tracking-wider">Id</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Prompt Processing</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
|
||||||
|
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y">
|
||||||
|
{sortedMetrics.map((metric) => (
|
||||||
|
<tr key={`metric_${metric.id}`}>
|
||||||
|
<td className="px-4 py-4 whitespace-nowrap text-sm">{metric.id + 1 /* un-zero index */}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.prompt_per_second)}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
|
||||||
|
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ActivityPage;
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
|
import {
|
||||||
|
RiTextWrap,
|
||||||
|
RiAlignJustify,
|
||||||
|
RiFontSize,
|
||||||
|
RiMenuSearchLine,
|
||||||
|
RiMenuSearchFill,
|
||||||
|
RiCloseCircleFill,
|
||||||
|
} from "react-icons/ri";
|
||||||
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
|
|
||||||
|
const LogViewer = () => {
|
||||||
|
const { proxyLogs, upstreamLogs } = useAPI();
|
||||||
|
const { isNarrow } = useTheme();
|
||||||
|
const direction = isNarrow ? "vertical" : "horizontal";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
|
||||||
|
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||||
|
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||||
|
</Panel>
|
||||||
|
<PanelResizeHandle
|
||||||
|
className={
|
||||||
|
direction === "horizontal"
|
||||||
|
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||||
|
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface LogPanelProps {
|
||||||
|
id: string;
|
||||||
|
title: string;
|
||||||
|
logData: string;
|
||||||
|
}
|
||||||
|
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
|
||||||
|
const [filterRegex, setFilterRegex] = useState("");
|
||||||
|
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||||
|
`logPanel-${id}-fontSize`,
|
||||||
|
"normal"
|
||||||
|
);
|
||||||
|
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||||
|
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
|
||||||
|
|
||||||
|
const textWrapClass = useMemo(() => {
|
||||||
|
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||||
|
}, [wrapText]);
|
||||||
|
|
||||||
|
const toggleFontSize = useCallback(() => {
|
||||||
|
setFontSize((prev) => {
|
||||||
|
switch (prev) {
|
||||||
|
case "xxs":
|
||||||
|
return "xs";
|
||||||
|
case "xs":
|
||||||
|
return "small";
|
||||||
|
case "small":
|
||||||
|
return "normal";
|
||||||
|
case "normal":
|
||||||
|
return "xxs";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const toggleWrapText = useCallback(() => {
|
||||||
|
setTextWrap((prev) => !prev);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const toggleFilter = useCallback(() => {
|
||||||
|
if (showFilter) {
|
||||||
|
setShowFilter(false);
|
||||||
|
setFilterRegex(""); // Clear filter when closing
|
||||||
|
} else {
|
||||||
|
setShowFilter(true);
|
||||||
|
}
|
||||||
|
}, [filterRegex, setFilterRegex, showFilter]);
|
||||||
|
|
||||||
|
const fontSizeClass = useMemo(() => {
|
||||||
|
switch (fontSize) {
|
||||||
|
case "xxs":
|
||||||
|
return "text-[0.5rem]"; // 0.5rem (8px)
|
||||||
|
case "xs":
|
||||||
|
return "text-[0.75rem]"; // 0.75rem (12px)
|
||||||
|
case "small":
|
||||||
|
return "text-[0.875rem]"; // 0.875rem (14px)
|
||||||
|
case "normal":
|
||||||
|
return "text-base"; // 1rem (16px)
|
||||||
|
}
|
||||||
|
}, [fontSize]);
|
||||||
|
|
||||||
|
const filteredLogs = useMemo(() => {
|
||||||
|
if (!filterRegex) return logData;
|
||||||
|
try {
|
||||||
|
const regex = new RegExp(filterRegex, "i");
|
||||||
|
const lines = logData.split("\n");
|
||||||
|
const filtered = lines.filter((line) => regex.test(line));
|
||||||
|
return filtered.join("\n");
|
||||||
|
} catch (e) {
|
||||||
|
return logData; // Return unfiltered if regex is invalid
|
||||||
|
}
|
||||||
|
}, [logData, filterRegex]);
|
||||||
|
|
||||||
|
// auto scroll to bottom
|
||||||
|
const preTagRef = useRef<HTMLPreElement>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
if (!preTagRef.current) return;
|
||||||
|
preTagRef.current.scrollTop = preTagRef.current.scrollHeight;
|
||||||
|
}, [filteredLogs]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
|
||||||
|
<div className="p-4 border-b border-border bg-secondary">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h3 className="m-0 text-lg p-0">{title}</h3>
|
||||||
|
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<button className="btn" onClick={toggleFontSize}>
|
||||||
|
<RiFontSize />
|
||||||
|
</button>
|
||||||
|
<button className="btn" onClick={toggleWrapText}>
|
||||||
|
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
|
||||||
|
</button>
|
||||||
|
<button className="btn" onClick={toggleFilter}>
|
||||||
|
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||||
|
{showFilter && (
|
||||||
|
<div className="mt-2 w-full">
|
||||||
|
<div className="flex gap-2 items-center w-full">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
className="w-full text-sm border p-2 rounded"
|
||||||
|
placeholder="Filter logs..."
|
||||||
|
value={filterRegex}
|
||||||
|
onChange={(e) => setFilterRegex(e.target.value)}
|
||||||
|
/>
|
||||||
|
<button className="pl-2" onClick={() => setFilterRegex("")}>
|
||||||
|
<RiCloseCircleFill size="24" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||||
|
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
|
||||||
|
{filteredLogs}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
export default LogViewer;
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
import { useState, useCallback, useMemo } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { LogPanel } from "./LogViewer";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
|
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
|
||||||
|
|
||||||
|
export default function ModelsPage() {
|
||||||
|
const { isNarrow } = useTheme();
|
||||||
|
const direction = isNarrow ? "vertical" : "horizontal";
|
||||||
|
const { upstreamLogs } = useAPI();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
|
||||||
|
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
|
||||||
|
<ModelsPanel />
|
||||||
|
</Panel>
|
||||||
|
|
||||||
|
<PanelResizeHandle
|
||||||
|
className={
|
||||||
|
direction === "horizontal"
|
||||||
|
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Panel collapsible={true} defaultSize={50} minSize={0}>
|
||||||
|
<div className="flex flex-col h-full space-y-4">
|
||||||
|
{direction === "horizontal" && <StatsPanel />}
|
||||||
|
<div className="flex-1 min-h-0">
|
||||||
|
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function ModelsPanel() {
|
||||||
|
const { models, loadModel, unloadAllModels } = useAPI();
|
||||||
|
const [isUnloading, setIsUnloading] = useState(false);
|
||||||
|
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||||
|
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||||
|
|
||||||
|
const filteredModels = useMemo(() => {
|
||||||
|
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||||
|
}, [models, showUnlisted]);
|
||||||
|
|
||||||
|
const handleUnloadAllModels = useCallback(async () => {
|
||||||
|
setIsUnloading(true);
|
||||||
|
try {
|
||||||
|
await unloadAllModels();
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
} finally {
|
||||||
|
setTimeout(() => {
|
||||||
|
setIsUnloading(false);
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
}, [unloadAllModels]);
|
||||||
|
|
||||||
|
const toggleIdorName = useCallback(() => {
|
||||||
|
setShowIdorName((prev) => (prev === "name" ? "id" : "name"));
|
||||||
|
}, [showIdorName]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="card h-full flex flex-col">
|
||||||
|
<div className="shrink-0">
|
||||||
|
<h2>Models</h2>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<button className="btn flex items-center gap-2" onClick={toggleIdorName} style={{ lineHeight: "1.2" }}>
|
||||||
|
<RiSwapBoxFill /> {showIdorName === "id" ? "ID" : "Name"}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<button
|
||||||
|
className="btn flex items-center gap-2"
|
||||||
|
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||||
|
style={{ lineHeight: "1.2" }}
|
||||||
|
>
|
||||||
|
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
|
||||||
|
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex-1 overflow-y-auto">
|
||||||
|
<table className="w-full">
|
||||||
|
<thead className="sticky top-0 bg-card z-10">
|
||||||
|
<tr className="border-b border-primary bg-surface">
|
||||||
|
<th className="text-left p-2">{showIdorName === "id" ? "Model ID" : "Name"}</th>
|
||||||
|
<th className="text-left p-2"></th>
|
||||||
|
<th className="text-left p-2">State</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{filteredModels.map((model) => (
|
||||||
|
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
|
||||||
|
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||||
|
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
|
||||||
|
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
|
||||||
|
</a>
|
||||||
|
{model.description !== "" && (
|
||||||
|
<p className={model.unlisted ? "text-opacity-70" : ""}>
|
||||||
|
<em>{model.description}</em>
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
<td className="p-2 w-[50px]">
|
||||||
|
<button
|
||||||
|
className="btn btn--sm"
|
||||||
|
disabled={model.state !== "stopped"}
|
||||||
|
onClick={() => loadModel(model.id)}
|
||||||
|
>
|
||||||
|
Load
|
||||||
|
</button>
|
||||||
|
</td>
|
||||||
|
<td className="p-2 w-[75px]">
|
||||||
|
<span className={`status status--${model.state}`}>{model.state}</span>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function StatsPanel() {
|
||||||
|
const { metrics } = useAPI();
|
||||||
|
|
||||||
|
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
|
||||||
|
const totalRequests = metrics.length;
|
||||||
|
if (totalRequests === 0) {
|
||||||
|
return [0, 0, 0];
|
||||||
|
}
|
||||||
|
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||||
|
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||||
|
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
||||||
|
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
|
||||||
|
}, [metrics]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="card">
|
||||||
|
<div className="rounded-lg overflow-hidden border border-gray-200">
|
||||||
|
<table className="w-full">
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<th className="p-2 font-medium border-b border-gray-200 text-right">Requests</th>
|
||||||
|
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Processed</th>
|
||||||
|
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Generated</th>
|
||||||
|
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Tokens/Sec</th>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td className="p-2 text-right border-r border-gray-200">{totalRequests}</td>
|
||||||
|
<td className="p-2 text-right border-r border-gray-200">
|
||||||
|
{new Intl.NumberFormat().format(totalInputTokens)}
|
||||||
|
</td>
|
||||||
|
<td className="p-2 text-right border-r border-gray-200">
|
||||||
|
{new Intl.NumberFormat().format(totalOutputTokens)}
|
||||||
|
</td>
|
||||||
|
<td className="p-2 text-right">{avgTokensPerSecond}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
/// <reference types="vite/client" />
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"verbatimModuleSyntax": true,
|
||||||
|
"moduleDetection": "force",
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"erasableSyntaxOnly": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noUncheckedSideEffectImports": true
|
||||||
|
},
|
||||||
|
"include": ["src"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"files": [],
|
||||||
|
"references": [
|
||||||
|
{ "path": "./tsconfig.app.json" },
|
||||||
|
{ "path": "./tsconfig.node.json" }
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
||||||
|
"target": "ES2022",
|
||||||
|
"lib": ["ES2023"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"verbatimModuleSyntax": true,
|
||||||
|
"moduleDetection": "force",
|
||||||
|
"noEmit": true,
|
||||||
|
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"erasableSyntaxOnly": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noUncheckedSideEffectImports": true
|
||||||
|
},
|
||||||
|
"include": ["vite.config.ts"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
import { defineConfig } from "vite";
|
||||||
|
import react from "@vitejs/plugin-react";
|
||||||
|
import tailwindcss from "@tailwindcss/vite";
|
||||||
|
|
||||||
|
// https://vite.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react(), tailwindcss()],
|
||||||
|
base: "/ui/",
|
||||||
|
build: {
|
||||||
|
outDir: "../proxy/ui_dist",
|
||||||
|
assetsDir: "assets",
|
||||||
|
},
|
||||||
|
server: {
|
||||||
|
proxy: {
|
||||||
|
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||||
|
"/logs": "http://localhost:8080",
|
||||||
|
"/upstream": "http://localhost:8080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||