Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 539278343b | |||
| 00b738cd0f | |||
| 70930e4e91 | |||
| 1f6179110c | |||
| 216c40b951 | |||
| 9e3d491c85 | |||
| 1a84926505 | |||
| fc3bb716df | |||
| c36986fef6 | |||
| 558801db1a | |||
| b21dee27c1 | |||
| f58c8c8ec5 | |||
| 954e2dee73 | |||
| a533aec736 | |||
| 97b17fc47d | |||
| 2457840698 | |||
| 7f55494151 | |||
| 831a90d3b0 | |||
| 977f1856bb | |||
| 52b329f7bc | |||
| 57803fd3aa | |||
| c55d0cc842 | |||
| 7acbaf4712 | |||
| fcc5ad135a | |||
| 305e5a0031 | |||
| 04fc67354a | |||
| 4662cf7699 | |||
| 5dc6b3e6d9 | |||
| 74c69f39ef | |||
| a186318892 | |||
| c4e4d5e1e9 | |||
| 7985e94ba4 | |||
| 74556c3a36 | |||
| 5c381e4b30 | |||
| 10569ed546 | |||
| 5b10b3c23f | |||
| 45ea792a3a | |||
| 1bc2802353 | |||
| 701476c0c4 | |||
| 5c63e0066c | |||
| 8be5073c51 | |||
| 6307bd3205 | |||
| 558a72de17 | |||
| dc42cf366d | |||
| ba0a81937a | |||
| 574fdfabb4 | |||
| 5172cb2e12 | |||
| 5672cb03fd | |||
| 0f583163f7 | |||
| 7905fa9ea3 | |||
| bbaf172956 | |||
| fd50932dbc | |||
| 8c693e7fcf | |||
| 8f2af26a41 | |||
| 01d4838fb3 | |||
| accd65294b | |||
| 7472a25864 | |||
| cce0bc6aa1 | |||
| 36e25125e8 | |||
| 9a54273d15 | |||
| 87dce5f8f6 | |||
| 307e619521 | |||
| 6299c1b874 | |||
| a906cd459b | |||
| 78b2bc3dbc | |||
| 6a058e4191 | |||
| 1921e570d7 | |||
| c867a6c9a2 | |||
| 3bd1b23ce0 | |||
| 10606abf89 | |||
| fefd14903d | |||
| 717d64e336 | |||
| 285191e655 | |||
| 4236cec03a | |||
| 756193d0dd | |||
| a6b2e930d8 | |||
| 9e02c22ff8 | |||
| 0bdbf2fdc1 | |||
| 49035e2e8e | |||
| 9963ae18bf | |||
| 2ae48c713b | |||
| 54c519e365 | |||
| 3fce9ee0e9 | |||
| 5899ae7966 | |||
| 591a9cdf4d | |||
| 9a3c656738 | |||
| 75015f82ea | |||
| cc33b6c270 | |||
| 4fa12a429c | |||
| 2dc0ca0663 | |||
| a84098d3b4 | |||
| 4d02ccd26a | |||
| dfd47eeac4 | |||
| 1ac6499c08 | |||
| 25f3dc25e7 | |||
| 8422e4e6a1 | |||
| 02ee29d881 | |||
| b2a891f8f4 | |||
| 8d2b568897 | |||
| fb44cf4e08 | |||
| 02aee4e86d | |||
| f45896d395 | |||
| f7e46a359f | |||
| c260907415 | |||
| b83a5fa291 | |||
| 6e2ff28d59 | |||
| a8b81f2799 | |||
| f9ee7156dc | |||
| 2d00120781 |
@@ -1,11 +1,13 @@
|
|||||||
---
|
---
|
||||||
name: Bug Report
|
name: Bug Report
|
||||||
about: Something is not working as expected...
|
about: I found a defect
|
||||||
title: ''
|
title: ''
|
||||||
labels: bug
|
labels: 'unconfirmed bug'
|
||||||
assignees: ''
|
assignees: ''
|
||||||
|
|
||||||
---
|
---
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
|
||||||
|
|
||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|||||||
@@ -22,6 +22,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
go-version: '1.23'
|
go-version: '1.23'
|
||||||
|
|
||||||
|
# Only run in this linux based runner
|
||||||
|
- name: Check Formatting
|
||||||
|
run: |
|
||||||
|
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
|
||||||
|
gofmt -l . | grep -v 'event/.*_test.go'
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
# cache simple-responder to save the build time
|
# cache simple-responder to save the build time
|
||||||
- name: Restore Simple Responder
|
- name: Restore Simple Responder
|
||||||
id: restore-simple-responder
|
id: restore-simple-responder
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ on:
|
|||||||
|
|
||||||
# Allows manual triggering of the workflow
|
# Allows manual triggering of the workflow
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: 'Tag version to release (e.g. v144)'
|
||||||
|
required: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
@@ -20,9 +24,22 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
|
-
|
||||||
|
name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '23'
|
||||||
|
-
|
||||||
|
name: Install dependencies and build UI
|
||||||
|
run: |
|
||||||
|
cd ui
|
||||||
|
npm ci
|
||||||
|
npm run build
|
||||||
|
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v6
|
uses: goreleaser/goreleaser-action@v6
|
||||||
@@ -33,4 +50,30 @@ jobs:
|
|||||||
version: '~> v2'
|
version: '~> v2'
|
||||||
args: release --clean
|
args: release --clean
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
trigger-tap-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: goreleaser
|
||||||
|
steps:
|
||||||
|
- name: "Resolve tag to dispatch"
|
||||||
|
id: tag
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||||
|
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: "Trigger tap repository update"
|
||||||
|
uses: peter-evans/repository-dispatch@v2
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.TAP_REPO_PAT }}
|
||||||
|
repository: mostlygeek/homebrew-llama-swap
|
||||||
|
event-type: new-release
|
||||||
|
client-payload: |
|
||||||
|
{
|
||||||
|
"release": {
|
||||||
|
"tag_name": "${{ steps.tag.outputs.tag }}"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,3 +4,4 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.dev/
|
||||||
|
|||||||
@@ -17,14 +17,16 @@ builds:
|
|||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
|
|
||||||
# use zip format for windows
|
|
||||||
archives:
|
archives:
|
||||||
- id: default
|
- id: default
|
||||||
format: tar.gz
|
formats:
|
||||||
|
- tar.gz
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
builds_info:
|
builds_info:
|
||||||
group: root
|
group: root
|
||||||
owner: root
|
owner: root
|
||||||
format_overrides:
|
format_overrides:
|
||||||
|
# use zip format for windows
|
||||||
- goos: windows
|
- goos: windows
|
||||||
format: zip
|
formats:
|
||||||
|
- zip
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
# Project: llama-swap
|
||||||
|
|
||||||
|
## Project Description:
|
||||||
|
|
||||||
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
|
## Tech stack
|
||||||
|
|
||||||
|
- golang
|
||||||
|
- typescript, vite and react for UI (ui/)
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors.
|
||||||
|
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||||
|
|
||||||
|
## Workflow Tasks
|
||||||
|
|
||||||
|
### Plan Improvements
|
||||||
|
|
||||||
|
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||||
|
|
||||||
|
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||||
|
|
||||||
|
- Identify any inconsistencies.
|
||||||
|
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||||
|
- Plans should have at least these sections:
|
||||||
|
- Title - very short, describes changes
|
||||||
|
- Overview: A more detailed summary of goal and outcomes desired
|
||||||
|
- Design Requirements: Detailed descriptions of what needs to be done
|
||||||
|
- Testing Plan: Tests to be implemented
|
||||||
|
- Checklist: A detailed list of changes to be made
|
||||||
|
|
||||||
|
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||||
|
|
||||||
|
### Implementation of plans
|
||||||
|
|
||||||
|
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||||
|
|
||||||
|
## General Rules
|
||||||
|
|
||||||
|
- when summarizing changes only include details that require further action (action items)
|
||||||
|
- when there are no action items, just say "Done."
|
||||||
@@ -19,24 +19,42 @@ all: mac linux simple-responder
|
|||||||
clean:
|
clean:
|
||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
proxy/ui_dist/placeholder.txt:
|
||||||
go test -short -v -count=1 ./proxy
|
mkdir -p proxy/ui_dist
|
||||||
|
touch $@
|
||||||
|
|
||||||
test-all:
|
# use cached test results while developing
|
||||||
go test -v -count=1 ./proxy
|
test-dev: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short ./proxy/...
|
||||||
|
staticcheck ./proxy/... || true
|
||||||
|
|
||||||
|
test: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -short -count=1 ./proxy/...
|
||||||
|
|
||||||
|
# for CI - full test (takes longer)
|
||||||
|
test-all: proxy/ui_dist/placeholder.txt
|
||||||
|
go test -count=1 ./proxy/...
|
||||||
|
|
||||||
|
ui/node_modules:
|
||||||
|
cd ui && npm install
|
||||||
|
|
||||||
|
# build react UI
|
||||||
|
ui: ui/node_modules
|
||||||
|
cd ui && npm run build
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
mac:
|
mac: ui
|
||||||
@echo "Building Mac binary..."
|
@echo "Building Mac binary..."
|
||||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||||
|
|
||||||
# Build Linux binary
|
# Build Linux binary
|
||||||
linux:
|
linux: ui
|
||||||
@echo "Building Linux binary..."
|
@echo "Building Linux binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
|
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||||
|
|
||||||
# Build Windows binary
|
# Build Windows binary
|
||||||
windows:
|
windows: ui
|
||||||
@echo "Building Windows binary..."
|
@echo "Building Windows binary..."
|
||||||
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe
|
||||||
|
|
||||||
@@ -69,4 +87,4 @@ release:
|
|||||||
git tag "$$new_tag";
|
git tag "$$new_tag";
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean mac linux windows simple-responder
|
.PHONY: all clean ui mac linux windows simple-responder test test-all test-dev
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||||
|
|
||||||
Written in golang, it is very easy to install (single binary with no dependancies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images.
|
Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary, a provided docker images or Homebrew.
|
||||||
|
|
||||||
## Features:
|
## Features:
|
||||||
|
|
||||||
@@ -18,188 +18,112 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- `v1/completions`
|
- `v1/completions`
|
||||||
- `v1/chat/completions`
|
- `v1/chat/completions`
|
||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
|
- ✅ llama-server (llama.cpp) supported endpoints:
|
||||||
|
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||||
|
- `/infill` - for code infilling
|
||||||
|
- `/completion` - for completion endpoint
|
||||||
- ✅ llama-swap custom API endpoints
|
- ✅ llama-swap custom API endpoints
|
||||||
|
- `/ui` - web UI
|
||||||
- `/log` - remote log monitoring
|
- `/log` - remote log monitoring
|
||||||
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
|
||||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||||
|
- `/health` - just returns "OK"
|
||||||
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||||
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
- ✅ Automatic unloading of models after timeout by setting a `ttl`
|
||||||
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
|
||||||
- ✅ Docker and Podman support
|
- ✅ Reliable Docker and Podman support using `cmd` and `cmdStop` together
|
||||||
- ✅ Full control over server settings per model
|
- ✅ Full control over server settings per model
|
||||||
|
- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
|
||||||
|
|
||||||
## How does llama-swap work?
|
## How does llama-swap work?
|
||||||
|
|
||||||
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
|
||||||
|
|
||||||
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration is purposefully simple.
|
llama-swap is managed entirely through a yaml configuration file.
|
||||||
|
|
||||||
|
It can be very minimal to start:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
models:
|
models:
|
||||||
"qwen2.5":
|
"qwen2.5":
|
||||||
proxy: "http://127.0.0.1:9999"
|
cmd: |
|
||||||
cmd: >
|
/path/to/llama-server
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
--port 9999
|
--port ${PORT}
|
||||||
|
|
||||||
"smollm2":
|
|
||||||
proxy: "http://127.0.0.1:9999"
|
|
||||||
cmd: >
|
|
||||||
/app/llama-server
|
|
||||||
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
|
||||||
--port 9999
|
|
||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
However, there are many more capabilities that llama-swap supports:
|
||||||
<summary>But also very powerful ...</summary>
|
|
||||||
|
|
||||||
```yaml
|
- `groups` to run multiple models at once
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
- `ttl` to automatically unload models
|
||||||
# Default (and minimum) is 15 seconds
|
- `macros` for reusable snippets
|
||||||
healthCheckTimeout: 60
|
- `aliases` to use familiar model names (e.g., "gpt-4o-mini")
|
||||||
|
- `env` to pass custom environment variables to inference servers
|
||||||
|
- `cmdStop` for to gracefully stop Docker/Podman containers
|
||||||
|
- `useModelName` to override model names sent to upstream servers
|
||||||
|
- `healthCheckTimeout` to control model startup wait times
|
||||||
|
- `${PORT}` automatic port variables for dynamic port assignment
|
||||||
|
|
||||||
# Valid log levels: debug, info (default), warn, error
|
See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki all options and examples.
|
||||||
logLevel: info
|
|
||||||
|
|
||||||
# Automatic Port Values
|
## Reverse Proxy Configuration (nginx)
|
||||||
# use ${PORT} in model.cmd and model.proxy to use an automatic port number
|
|
||||||
# when you use ${PORT} you can omit a custom model.proxy value, as it will
|
|
||||||
# default to http://localhost:${PORT}
|
|
||||||
|
|
||||||
# override the default port (5800) for automatic port values
|
If you deploy llama-swap behind nginx, disable response buffering for streaming endpoints. By default, nginx buffers responses which breaks Server‑Sent Events (SSE) and streaming chat completion. ([#236](https://github.com/mostlygeek/llama-swap/issues/236))
|
||||||
startPort: 10001
|
|
||||||
|
|
||||||
# define valid model values and the upstream server start
|
Recommended nginx configuration snippets:
|
||||||
models:
|
|
||||||
"llama":
|
|
||||||
# multiline for readability
|
|
||||||
cmd: >
|
|
||||||
llama-server --port 8999
|
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
|
||||||
|
|
||||||
# environment variables to pass to the command
|
```nginx
|
||||||
env:
|
# SSE for UI events/logs
|
||||||
- "CUDA_VISIBLE_DEVICES=0"
|
location /api/events {
|
||||||
|
proxy_pass http://your-llama-swap-backend;
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
|
}
|
||||||
|
|
||||||
# where to reach the server started by cmd, make sure the ports match
|
# Streaming chat completions (stream=true)
|
||||||
# can be omitted if you use an automatic ${PORT} in cmd
|
location /v1/chat/completions {
|
||||||
proxy: http://127.0.0.1:8999
|
proxy_pass http://your-llama-swap-backend;
|
||||||
|
proxy_buffering off;
|
||||||
# aliases names to use this model for
|
proxy_cache off;
|
||||||
aliases:
|
}
|
||||||
- "gpt-4o-mini"
|
|
||||||
- "gpt-3.5-turbo"
|
|
||||||
|
|
||||||
# check this path for an HTTP 200 OK before serving requests
|
|
||||||
# default: /health to match llama.cpp
|
|
||||||
# use "none" to skip endpoint checking, but may cause HTTP errors
|
|
||||||
# until the model is ready
|
|
||||||
checkEndpoint: /custom-endpoint
|
|
||||||
|
|
||||||
# automatically unload the model after this many seconds
|
|
||||||
# ttl values must be a value greater than 0
|
|
||||||
# default: 0 = never unload model
|
|
||||||
ttl: 60
|
|
||||||
|
|
||||||
# `useModelName` overrides the model name in the request
|
|
||||||
# and sends a specific name to the upstream server
|
|
||||||
useModelName: "qwen:qwq"
|
|
||||||
|
|
||||||
# unlisted models do not show up in /v1/models or /upstream lists
|
|
||||||
# but they can still be requested as normal
|
|
||||||
"qwen-unlisted":
|
|
||||||
unlisted: true
|
|
||||||
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
|
||||||
"docker-llama":
|
|
||||||
proxy: "http://127.0.0.1:${PORT}"
|
|
||||||
cmd: >
|
|
||||||
docker run --name dockertest
|
|
||||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
|
||||||
|
|
||||||
# Groups provide advanced controls over model swapping behaviour. Using groups
|
|
||||||
# some models can be kept loaded indefinitely, while others are swapped out.
|
|
||||||
#
|
|
||||||
# Tips:
|
|
||||||
#
|
|
||||||
# - models must be defined above in the Models section
|
|
||||||
# - a model can only be a member of one group
|
|
||||||
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
|
||||||
# - see issue #109 for details
|
|
||||||
#
|
|
||||||
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
|
||||||
groups:
|
|
||||||
# group1 is the default behaviour of llama-swap where only one model is allowed
|
|
||||||
# to run a time across the whole llama-swap instance
|
|
||||||
"group1":
|
|
||||||
# swap controls the model swapping behaviour in within the group
|
|
||||||
# - true : only one model is allowed to run at a time
|
|
||||||
# - false: all models can run together, no swapping
|
|
||||||
swap: true
|
|
||||||
|
|
||||||
# exclusive controls how the group affects other groups
|
|
||||||
# - true: causes all other groups to unload their models when this group runs a model
|
|
||||||
# - false: does not affect other groups
|
|
||||||
exclusive: true
|
|
||||||
|
|
||||||
# members references the models defined above
|
|
||||||
members:
|
|
||||||
- "llama"
|
|
||||||
- "qwen-unlisted"
|
|
||||||
|
|
||||||
# models in this group are never unloaded
|
|
||||||
"group2":
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "docker-llama"
|
|
||||||
# (not defined above, here for example)
|
|
||||||
- "modelA"
|
|
||||||
- "modelB"
|
|
||||||
|
|
||||||
"forever":
|
|
||||||
# setting persistent to true causes the group to never be affected by the swapping behaviour of
|
|
||||||
# other groups. It is a shortcut to keeping some models always loaded.
|
|
||||||
persistent: true
|
|
||||||
|
|
||||||
# set swap/exclusive to false to prevent swapping inside the group and effect on other groups
|
|
||||||
swap: false
|
|
||||||
exclusive: false
|
|
||||||
members:
|
|
||||||
- "forever-modelA"
|
|
||||||
- "forever-modelB"
|
|
||||||
- "forever-modelc"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Case Examples
|
As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. However, explicitly disabling `proxy_buffering` at your reverse proxy is still recommended for reliable streaming behavior.
|
||||||
|
|
||||||
- [config.example.yaml](config.example.yaml) includes example for supporting `v1/embeddings` and `v1/rerank` endpoints
|
## Web UI
|
||||||
- [Speculative Decoding](examples/speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases.
|
|
||||||
- [Optimizing Code Generation](examples/benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest.
|
|
||||||
- [Restart on Config Change](examples/restart-on-config-change/README.md) - automatically restart llama-swap when trying out different configurations.
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
llama-swap includes a real time web interface for monitoring logs and models:
|
||||||
|
|
||||||
Docker is the quickest way to try out llama-swap:
|
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
|
||||||
|
|
||||||
|
The Activity Page shows recent requests:
|
||||||
|
|
||||||
|
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
llama-swap can be installed in multiple ways
|
||||||
|
|
||||||
|
1. Docker
|
||||||
|
2. Homebrew (OSX and Linux)
|
||||||
|
3. From release binaries
|
||||||
|
4. From source
|
||||||
|
|
||||||
|
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
|
||||||
|
|
||||||
|
Docker images with llama-swap and llama-server are built nightly.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# use CPU inference
|
# use CPU inference comes with the example config above
|
||||||
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
|
||||||
|
|
||||||
|
|
||||||
# qwen2.5 0.5B
|
# qwen2.5 0.5B
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
@@ -207,7 +131,6 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
|||||||
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
-d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \
|
||||||
jq -r '.choices[0].message.content'
|
jq -r '.choices[0].message.content'
|
||||||
|
|
||||||
|
|
||||||
# SmolLM2 135M
|
# SmolLM2 135M
|
||||||
$ curl -s http://localhost:9292/v1/chat/completions \
|
$ curl -s http://localhost:9292/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
@@ -217,7 +140,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
|
|||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Docker images are nightly ...</summary>
|
<summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
|
||||||
|
|
||||||
They include:
|
They include:
|
||||||
|
|
||||||
@@ -240,31 +163,45 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases))
|
### Homebrew Install (macOS/Linux)
|
||||||
|
|
||||||
Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server.
|
The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Set up tap and install formula
|
||||||
|
brew tap mostlygeek/llama-swap
|
||||||
|
brew install llama-swap
|
||||||
|
# Run llama-swap
|
||||||
|
llama-swap --config path/to/config.yaml --listen localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
|
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
|
||||||
|
|
||||||
|
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
|
||||||
|
|
||||||
|
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
|
||||||
|
|
||||||
1. Create a configuration file, see [config.example.yaml](config.example.yaml)
|
|
||||||
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
|
||||||
1. Run the binary with `llama-swap --config path/to/config.yaml`.
|
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
|
||||||
Available flags:
|
1. Run the binary with `llama-swap --config path/to/config.yaml --listen localhost:8080`.
|
||||||
- `--config`: Path to the configuration file (default: `config.yaml`).
|
Available flags:
|
||||||
- `--listen`: Address and port to listen on (default: `:8080`).
|
- `--config`: Path to the configuration file (default: `config.yaml`).
|
||||||
- `--version`: Show version information and exit.
|
- `--listen`: Address and port to listen on (default: `:8080`).
|
||||||
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
- `--version`: Show version information and exit.
|
||||||
|
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
|
||||||
|
|
||||||
### Building from source
|
### Building from source
|
||||||
|
|
||||||
1. Install golang for your system
|
1. Build requires golang and nodejs for the user interface.
|
||||||
1. `git clone git@github.com:mostlygeek/llama-swap.git`
|
1. `git clone https://github.com/mostlygeek/llama-swap.git`
|
||||||
1. `make clean all`
|
1. `make clean all`
|
||||||
1. Binaries will be in `build/` subdirectory
|
1. Binaries will be in `build/` subdirectory
|
||||||
|
|
||||||
## Monitoring Logs
|
## Monitoring Logs
|
||||||
|
|
||||||
Open the `http://<host>/logs` with your browser to get a web interface with streaming logs.
|
Open the `http://<host>:<port>/` with your browser to get a web interface with streaming logs.
|
||||||
|
|
||||||
Of course, CLI access is also supported:
|
CLI access is also supported:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# sends up to the last 10KB of logs
|
# sends up to the last 10KB of logs
|
||||||
@@ -292,32 +229,9 @@ Any OpenAI compatible server would work. llama-swap was originally designed for
|
|||||||
|
|
||||||
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown.
|
||||||
|
|
||||||
## Systemd Unit Files
|
|
||||||
|
|
||||||
Use this unit file to start llama-swap on boot. This is only tested on Ubuntu.
|
|
||||||
|
|
||||||
`/etc/systemd/system/llama-swap.service`
|
|
||||||
|
|
||||||
```
|
|
||||||
[Unit]
|
|
||||||
Description=llama-swap
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
User=nobody
|
|
||||||
|
|
||||||
# set this to match your environment
|
|
||||||
ExecStart=/path/to/llama-swap --config /path/to/llama-swap.config.yml
|
|
||||||
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=3
|
|
||||||
StartLimitBurst=3
|
|
||||||
StartLimitInterval=30
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> ⭐️ Star this project to help others discover it!
|
||||||
|
|
||||||
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
[](https://www.star-history.com/#mostlygeek/llama-swap&Date)
|
||||||
|
|||||||
@@ -0,0 +1,292 @@
|
|||||||
|
# Add Model Metadata Support with Typed Macros
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Implement support for arbitrary metadata on model configurations that can be exposed through the `/v1/models` API endpoint. This feature extends the existing macro system to support scalar types (string, int, float, bool) instead of only strings, enabling type-safe metadata values.
|
||||||
|
|
||||||
|
The metadata will be schemaless, allowing users to define any key-value pairs they need. Macro substitution will work within metadata values, preserving types when macros are used directly and converting to strings when macros are interpolated within strings.
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. Enhanced Macro System
|
||||||
|
|
||||||
|
**Current State:**
|
||||||
|
|
||||||
|
- Macros are defined as `map[string]string` at both global and model levels
|
||||||
|
- Only string substitution is supported
|
||||||
|
- Macros are replaced in: `cmd`, `cmdStop`, `proxy`, `checkEndpoint`, `filters.stripParams`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Change `MacroList` type from `map[string]string` to `map[string]any`
|
||||||
|
- Support scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Implement type-preserving macro substitution:
|
||||||
|
- Direct macro usage (`key: ${macro}`) preserves the macro's type
|
||||||
|
- Interpolated usage (`key: "text ${macro}"`) converts to string
|
||||||
|
- Add validation to ensure macro values are scalar types only
|
||||||
|
- Update existing macro substitution logic in [proxy/config/config.go](proxy/config/config.go) to handle `any` types
|
||||||
|
|
||||||
|
**Implementation Details:**
|
||||||
|
|
||||||
|
- Create a generic helper function to perform macro substitution that:
|
||||||
|
- Takes a value of type `any`
|
||||||
|
- Recursively processes maps, slices, and scalar values
|
||||||
|
- Replaces `${macro_name}` patterns with macro values
|
||||||
|
- Preserves types for direct substitution
|
||||||
|
- Converts to strings for interpolated substitution
|
||||||
|
- Update `validateMacro()` function to accept `any` type and validate scalar types
|
||||||
|
- Maintain backward compatibility with existing string-only macros
|
||||||
|
|
||||||
|
### 2. Metadata Field in ModelConfig
|
||||||
|
|
||||||
|
**Location:** [proxy/config/model_config.go](proxy/config/model_config.go)
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add `Metadata map[string]any` field to `ModelConfig` struct
|
||||||
|
- Support YAML unmarshaling of arbitrary structures (maps, arrays, scalars)
|
||||||
|
- Apply macro substitution to metadata values during config loading
|
||||||
|
|
||||||
|
**Schema Requirements:**
|
||||||
|
|
||||||
|
- Metadata is optional (default: empty/nil map)
|
||||||
|
- Supports nested structures (objects within objects, arrays, etc.)
|
||||||
|
- All string values within metadata undergo macro substitution
|
||||||
|
- Type preservation rules apply as described above
|
||||||
|
|
||||||
|
### 3. Macro Substitution in Metadata
|
||||||
|
|
||||||
|
**Location:** [proxy/config/config.go](proxy/config/config.go) in `LoadConfigFromReader()`
|
||||||
|
|
||||||
|
**Process Flow:**
|
||||||
|
|
||||||
|
1. After loading YAML configuration
|
||||||
|
2. After model-level and global macro merging
|
||||||
|
3. Apply macro substitution to `ModelConfig.Metadata` field
|
||||||
|
4. Use the same merged macros available to `cmd`, `proxy`, etc.
|
||||||
|
5. Process recursively through all nested structures
|
||||||
|
|
||||||
|
**Substitution Rules:**
|
||||||
|
|
||||||
|
- `port: ${PORT}` → keeps integer type from PORT macro
|
||||||
|
- `temperature: ${temp}` → keeps float type from temp macro
|
||||||
|
- `note: "Running on ${PORT}"` → converts to string `"Running on 10001"`
|
||||||
|
- Arrays and nested objects are processed recursively
|
||||||
|
- Unknown macros should cause configuration load error (consistent with existing behavior)
|
||||||
|
|
||||||
|
### 4. API Response Updates
|
||||||
|
|
||||||
|
**Location:** [proxy/proxymanager.go:350](proxy/proxymanager.go#L350) `listModelsHandler()`
|
||||||
|
|
||||||
|
**Current Behavior:**
|
||||||
|
|
||||||
|
- Returns model records with: `id`, `object`, `created`, `owned_by`
|
||||||
|
- Optionally includes: `name`, `description`
|
||||||
|
|
||||||
|
**Required Changes:**
|
||||||
|
|
||||||
|
- Add metadata to each model record under the key `llamaswap_meta`
|
||||||
|
- Only include `llamaswap_meta` if metadata is non-empty
|
||||||
|
- Preserve all types when marshaling to JSON
|
||||||
|
- Maintain existing sorting by model ID
|
||||||
|
|
||||||
|
**Example Response:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "llama",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1234567890,
|
||||||
|
"owned_by": "llama-swap",
|
||||||
|
"name": "llama 3.1 8B",
|
||||||
|
"description": "A small but capable model",
|
||||||
|
"llamaswap_meta": {
|
||||||
|
"port": 10001,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"note": "The llama is running on port 10001 temp=0.7, context=16384",
|
||||||
|
"a_list": [1, 1.23, "macros are OK in list and dictionary types: llama"],
|
||||||
|
"an_obj": {
|
||||||
|
"a": "1",
|
||||||
|
"b": 2,
|
||||||
|
"c": [0.7, false, "model: llama"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Validation and Error Handling
|
||||||
|
|
||||||
|
**Macro Validation:**
|
||||||
|
|
||||||
|
- Extend `validateMacro()` to accept values of type `any`
|
||||||
|
- Verify macro values are scalar types: `string`, `int`, `float64`, `bool`
|
||||||
|
- Reject complex types (maps, slices, structs) as macro values
|
||||||
|
- Maintain existing validation for macro names and lengths
|
||||||
|
|
||||||
|
**Configuration Loading:**
|
||||||
|
|
||||||
|
- Fail fast if unknown macros are found in metadata
|
||||||
|
- Provide clear error messages indicating which model and field contains errors
|
||||||
|
- Ensure macros in metadata follow same rules as macros in cmd/proxy fields
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### Test 1: Model-Level Macros with Different Types
|
||||||
|
|
||||||
|
**File:** [proxy/config/model_config_test.go](proxy/config/model_config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define model with macros of each scalar type
|
||||||
|
- Verify metadata correctly substitutes and preserves types
|
||||||
|
- Test direct substitution (`port: ${PORT}`)
|
||||||
|
- Test string interpolation (`note: "Port is ${PORT}"`)
|
||||||
|
- Verify nested objects and arrays work correctly
|
||||||
|
|
||||||
|
### Test 2: Global and Model Macro Precedence
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Define same macro at global and model level with different types
|
||||||
|
- Verify model-level macro takes precedence
|
||||||
|
- Test metadata uses correct macro value
|
||||||
|
- Verify type is preserved from the winning macro
|
||||||
|
|
||||||
|
### Test 3: Macro Validation
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Test that complex types (maps, arrays) are rejected as macro values
|
||||||
|
- Verify error message includes: macro name and type that was rejected
|
||||||
|
- Test that scalar types (string, int, float, bool) are accepted
|
||||||
|
- Each type should load without error
|
||||||
|
- Test macro name validation still works with `any` types
|
||||||
|
- Invalid characters, reserved names, length limits should still be enforced
|
||||||
|
|
||||||
|
### Test 4: Metadata in API Response
|
||||||
|
|
||||||
|
**File:** [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
|
||||||
|
**Existing Test:** `TestProxyManager_ListModelsHandler`
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Model with metadata → verify `llamaswap_meta` key appears
|
||||||
|
- Model without metadata → verify `llamaswap_meta` key is absent
|
||||||
|
- Verify all types are correctly marshaled to JSON
|
||||||
|
- Verify nested structures are preserved
|
||||||
|
- Verify macro substitution has occurred before serialization
|
||||||
|
|
||||||
|
### Test 5: Unknown Macros in Metadata
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Use undefined macro in metadata
|
||||||
|
- Verify configuration loading fails with clear error
|
||||||
|
- Error should indicate model name and that macro is undefined
|
||||||
|
|
||||||
|
### Test 6: Recursive Substitution
|
||||||
|
|
||||||
|
**File:** [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
**Test Cases:**
|
||||||
|
|
||||||
|
- Metadata with deeply nested structures
|
||||||
|
- Arrays containing objects with macros
|
||||||
|
- Objects containing arrays with macros
|
||||||
|
- Mixed string interpolation and direct substitution at various nesting levels
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Configuration Schema Changes
|
||||||
|
|
||||||
|
- [x] Change `MacroList` type from `map[string]string` to `map[string]any` in [proxy/config/config.go:19](proxy/config/config.go#L19)
|
||||||
|
- [x] Add `Metadata map[string]any` field to `ModelConfig` struct in [proxy/config/model_config.go:37](proxy/config/model_config.go#L37)
|
||||||
|
- [x] Update `validateMacro()` function signature to accept `any` type for values
|
||||||
|
- [x] Add validation logic to ensure macro values are scalar types only
|
||||||
|
|
||||||
|
### Macro Substitution Logic
|
||||||
|
|
||||||
|
- [x] Create generic recursive function `substituteMetadataMacros()` to handle `any` types
|
||||||
|
- [x] Implement type-preserving direct substitution logic
|
||||||
|
- [x] Implement string interpolation with type conversion
|
||||||
|
- [x] Handle maps: recursively process all values
|
||||||
|
- [x] Handle slices: recursively process all elements
|
||||||
|
- [x] Handle scalar types: perform string-based macro substitution if value is string
|
||||||
|
- [x] Integrate macro substitution into `LoadConfigFromReader()` after existing macro expansion
|
||||||
|
- [x] Update existing macro substitution calls to use merged macros with correct types
|
||||||
|
|
||||||
|
### API Response Changes
|
||||||
|
|
||||||
|
- [x] Modify `listModelsHandler()` in [proxy/proxymanager.go:350](proxy/proxymanager.go#L350)
|
||||||
|
- [x] Add `llamaswap_meta` field to model records when metadata exists
|
||||||
|
- [x] Ensure empty metadata results in omitted `llamaswap_meta` key
|
||||||
|
- [x] Verify JSON marshaling preserves all types correctly
|
||||||
|
|
||||||
|
### Testing - Config Package
|
||||||
|
|
||||||
|
- [x] Add test for string macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for int macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for float macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for bool macros in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for string interpolation in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for model-level macro precedence: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for nested structures in metadata: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for unknown macro in metadata (should error): [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
- [x] Add test for invalid macro type validation: [proxy/config/config_test.go](proxy/config/config_test.go)
|
||||||
|
|
||||||
|
### Testing - Model Config Package
|
||||||
|
|
||||||
|
- [x] Add test cases to [proxy/config/model_config_test.go](proxy/config/model_config_test.go) for metadata unmarshaling
|
||||||
|
- [x] Test metadata with various scalar types
|
||||||
|
- [x] Test metadata with nested objects and arrays
|
||||||
|
|
||||||
|
### Testing - Proxy Manager
|
||||||
|
|
||||||
|
- [x] Update `TestProxyManager_ListModelsHandler` in [proxy/proxymanager_test.go](proxy/proxymanager_test.go)
|
||||||
|
- [x] Add test case for model with metadata
|
||||||
|
- [x] Add test case for model without metadata
|
||||||
|
- [x] Verify `llamaswap_meta` key presence/absence
|
||||||
|
- [x] Verify type preservation in JSON output
|
||||||
|
- [x] Verify macro substitution has occurred
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- [x] Verify [config.example.yaml](config.example.yaml) already has complete metadata examples (lines 149-171)
|
||||||
|
- [x] No additional documentation needed per project instructions
|
||||||
|
|
||||||
|
## Known Issues and Considerations
|
||||||
|
|
||||||
|
### Inconsistencies
|
||||||
|
|
||||||
|
None identified. The plan references the correct existing example in [config.example.yaml:149-171](config.example.yaml#L149-L171).
|
||||||
|
|
||||||
|
### Design Decisions
|
||||||
|
|
||||||
|
1. **Why `llamaswap_meta` instead of merging into record?**
|
||||||
|
|
||||||
|
- Avoids potential collisions with OpenAI API standard fields
|
||||||
|
- Makes it clear this is llama-swap specific metadata
|
||||||
|
- Easier for clients to distinguish standard vs. custom fields
|
||||||
|
|
||||||
|
2. **Why support nested structures?**
|
||||||
|
|
||||||
|
- Provides maximum flexibility for users
|
||||||
|
- Aligns with the schemaless design principle
|
||||||
|
- Example config already demonstrates this capability
|
||||||
|
|
||||||
|
3. **Why validate macro types?**
|
||||||
|
- Prevents confusing behavior (e.g., substituting a map)
|
||||||
|
- Makes configuration errors explicit at load time
|
||||||
|
- Simpler implementation and testing
|
||||||
@@ -0,0 +1,397 @@
|
|||||||
|
# Improve macro-in-macro support
|
||||||
|
|
||||||
|
**Status: COMPLETED ✅**
|
||||||
|
|
||||||
|
## Title
|
||||||
|
|
||||||
|
Fix macro substitution ordering by preserving definition order using ordered YAML parsing
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The current macro implementation uses `map[string]any` which does not preserve insertion order. This causes issues when macros reference other macros - if macro `B` contains `${A}` but `B` is processed before `A`, the reference won't be substituted, leading to "unknown macro" errors.
|
||||||
|
|
||||||
|
**Goal:** Ensure macros are substituted in definition order (LIFO - last in, first out) to allow macros to reliably reference previously-defined macros.
|
||||||
|
|
||||||
|
**Outcomes:**
|
||||||
|
- Macros can reference other macros defined earlier in the config
|
||||||
|
- Macro substitution is deterministic and order-dependent
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
- Use `yaml.Node` from `gopkg.in/yaml.v3` to preserve macro definition order
|
||||||
|
- All existing tests pass
|
||||||
|
- New tests validate substitution order and self-reference detection
|
||||||
|
|
||||||
|
## Design Requirements
|
||||||
|
|
||||||
|
### 1. YAML Parsing Strategy
|
||||||
|
- **Continue using:** `gopkg.in/yaml.v3` (current library)
|
||||||
|
- **Use:** `yaml.Node` for ordered parsing of macros
|
||||||
|
- **Reason:** `yaml.Node` preserves document structure and order, avoiding need for migration
|
||||||
|
|
||||||
|
### 2. Data Structure Changes
|
||||||
|
|
||||||
|
#### Current Implementation (config.go:19)
|
||||||
|
```go
|
||||||
|
type MacroList map[string]any
|
||||||
|
```
|
||||||
|
|
||||||
|
#### New Implementation
|
||||||
|
```go
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation Note:** Parse macros using `yaml.Node` to extract key-value pairs in document order, then construct the ordered `MacroList`.
|
||||||
|
|
||||||
|
### 3. Macro Substitution Order Rules
|
||||||
|
|
||||||
|
The substitution must follow this hierarchy (from most specific to least):
|
||||||
|
|
||||||
|
1. **Reserved macros** (last): `PORT`, `MODEL_ID` - substituted last, highest priority
|
||||||
|
2. **Model-level macros** (middle): Defined in specific model config, overrides global
|
||||||
|
3. **Global macros** (first): Defined at config root level
|
||||||
|
|
||||||
|
Within each level, macros are substituted in **reverse definition order** (LIFO):
|
||||||
|
- The last macro defined is substituted first
|
||||||
|
- This allows later macros to reference earlier ones
|
||||||
|
- Single-pass substitution prevents circular dependencies
|
||||||
|
|
||||||
|
### 4. Macro Reference Rules
|
||||||
|
|
||||||
|
**Allowed:**
|
||||||
|
- Macro can reference any macro defined **before** it (earlier in the file)
|
||||||
|
- Model macros can reference global macros
|
||||||
|
- Macros can reference reserved macros (`${PORT}`, `${MODEL_ID}`)
|
||||||
|
|
||||||
|
**Prohibited:**
|
||||||
|
- Macro cannot reference itself (e.g., `foo: "value ${foo}"`)
|
||||||
|
- Macro cannot reference macros defined **after** it
|
||||||
|
- No circular references (prevented by single-pass, ordered substitution)
|
||||||
|
|
||||||
|
### 5. Validation Requirements
|
||||||
|
|
||||||
|
Add validation to detect:
|
||||||
|
- **Self-references:** Macro value contains reference to its own name
|
||||||
|
- **Unknown macros:** After substitution, any remaining `${...}` references
|
||||||
|
|
||||||
|
Error messages should be clear:
|
||||||
|
```
|
||||||
|
macro 'foo' contains self-reference
|
||||||
|
unknown macro '${bar}' in model.cmd
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Implementation Changes
|
||||||
|
|
||||||
|
#### Files to Modify
|
||||||
|
|
||||||
|
1. **[proxy/config/config.go](proxy/config/config.go)**
|
||||||
|
- Line 19: Change `MacroList` type definition
|
||||||
|
- Line 69: Update `Macros MacroList` field
|
||||||
|
- Line 153-157: Update macro validation loop to work with ordered structure
|
||||||
|
- Line 175-188: Update model-level macro validation
|
||||||
|
- Line 181-188: **NEW** Implement proper macro merging respecting order
|
||||||
|
- Line 193-202: **NEW** Implement ordered macro substitution in LIFO order
|
||||||
|
- Line 389-415: Update `validateMacro` to detect self-references
|
||||||
|
- Line 420-475: Update `substituteMetadataMacros` to accept ordered MacroList
|
||||||
|
|
||||||
|
2. **[proxy/config/model_config.go](proxy/config/model_config.go)**
|
||||||
|
- Line 33: Update `Macros MacroList` field type
|
||||||
|
|
||||||
|
3. **All test files**
|
||||||
|
- Update test fixtures to use ordered macro definitions
|
||||||
|
- Ensure tests specify macro order explicitly
|
||||||
|
|
||||||
|
#### Core Algorithm
|
||||||
|
|
||||||
|
Replace the macro substitution logic in [config.go:181-252](proxy/config/config.go#L181-L252) with:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+2)
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
for _, entry := range config.Macros {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add reserved MODEL_ID macro at the end
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Check if PORT macro is needed
|
||||||
|
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
|
||||||
|
// enforce ${PORT} used in both cmd and proxy
|
||||||
|
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro to the end (highest priority)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "PORT", Value: nextPort})
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single-pass substitution: Substitute all macros in LIFO order (last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
modelConfig.Metadata, err = substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this new helper function to replace `substituteMetadataMacros`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Self-Reference Detection
|
||||||
|
|
||||||
|
Add to `validateMacro` function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
// ... existing validation ...
|
||||||
|
|
||||||
|
// Check for self-reference
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(str, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Plan
|
||||||
|
|
||||||
|
### 1. Migration Tests
|
||||||
|
- **Test:** All existing macro tests still pass after YAML library migration
|
||||||
|
- **Files:** All `*_test.go` files with macro tests
|
||||||
|
|
||||||
|
### 2. Macro Order Tests
|
||||||
|
|
||||||
|
#### Test: Macro-in-macro substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "echo ${B}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo prefix-value-A-suffix"`
|
||||||
|
|
||||||
|
#### Test: LIFO substitution order
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: "load ${full}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"load /models/llama/model.gguf"`
|
||||||
|
|
||||||
|
#### Test: Model macro overrides global
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: "echo ${msg}"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"echo value-model-level"` (model macro overrides global)
|
||||||
|
|
||||||
|
### 3. Reserved Macro Tests
|
||||||
|
|
||||||
|
#### Test: MODEL_ID substituted in macro
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: "${podman-llama} -m model.gguf"
|
||||||
|
```
|
||||||
|
**Expected:** `cmd` becomes `"podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf"`
|
||||||
|
|
||||||
|
### 4. Error Detection Tests
|
||||||
|
|
||||||
|
#### Test: Self-reference detection
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `macro 'recursive' contains self-reference`
|
||||||
|
|
||||||
|
#### Test: Undefined macro reference
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
```
|
||||||
|
**Expected:** Error: `unknown macro '${UNDEFINED}' found in macros.A` (or similar)
|
||||||
|
|
||||||
|
### 5. Regression Tests
|
||||||
|
- Run all existing macro tests: `TestConfig_MacroReplacement`, `TestConfig_MacroReservedNames`, etc.
|
||||||
|
- Ensure all pass without modification (except test fixtures if needed)
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
### Phase 1: Data Structure Changes
|
||||||
|
- [ ] Implement custom `UnmarshalYAML` method for `MacroList` that uses `yaml.Node`
|
||||||
|
- [ ] Define new ordered `MacroList` type as `[]MacroEntry`
|
||||||
|
- [ ] Update `MacroList` type definition in [config.go](proxy/config/config.go#L19)
|
||||||
|
- [ ] Update `Config.Macros` field type in [config.go](proxy/config/config.go#L69)
|
||||||
|
- [ ] Update `ModelConfig.Macros` field type in [model_config.go](proxy/config/model_config.go#L33)
|
||||||
|
- [ ] Implement helper functions:
|
||||||
|
- [ ] `func (ml MacroList) Get(name string) (any, bool)` - lookup by name
|
||||||
|
- [ ] `func (ml MacroList) Set(name string, value any) MacroList` - add/override entry
|
||||||
|
- [ ] `func (ml MacroList) ToMap() map[string]any` - convert to map if needed
|
||||||
|
|
||||||
|
### Phase 2: Macro Validation Updates
|
||||||
|
- [ ] Update macro validation loop at [config.go:153-157](proxy/config/config.go#L153-L157)
|
||||||
|
- [ ] Update model macro validation at [config.go:175-179](proxy/config/config.go#L175-L179)
|
||||||
|
- [ ] Add self-reference detection to `validateMacro` function [config.go:389](proxy/config/config.go#L389)
|
||||||
|
- [ ] Test self-reference detection with new test case
|
||||||
|
|
||||||
|
### Phase 3: Macro Substitution Algorithm
|
||||||
|
- [ ] Implement ordered macro merging (global → model → reserved) at [config.go:181-188](proxy/config/config.go#L181-L188)
|
||||||
|
- [ ] Implement single-pass LIFO substitution loop (reverse iteration) at [config.go:193-202](proxy/config/config.go#L193-L202)
|
||||||
|
- [ ] Substitute in all string fields (cmd, cmdStop, proxy, checkEndpoint, stripParams)
|
||||||
|
- [ ] Substitute in metadata within same loop
|
||||||
|
- [ ] Ensure `MODEL_ID` is added to merged macros before substitution
|
||||||
|
- [ ] Ensure `PORT` is added after port assignment (if needed)
|
||||||
|
- [ ] Replace `substituteMetadataMacros` with new `substituteMacroInValue` function that processes one macro at a time [config.go:420](proxy/config/config.go#L420)
|
||||||
|
- [ ] Remove old metadata substitution code that was separate from main loop [config.go:245-251](proxy/config/config.go#L245-L251)
|
||||||
|
|
||||||
|
### Phase 4: Testing
|
||||||
|
- [ ] Run `make test-dev` - fix any static checking errors
|
||||||
|
- [ ] Add test: macro-in-macro basic substitution
|
||||||
|
- [ ] Add test: LIFO substitution order with 3+ macro levels
|
||||||
|
- [ ] Add test: MODEL_ID in global macro used by model
|
||||||
|
- [ ] Add test: PORT in global macro used by model
|
||||||
|
- [ ] Add test: model macro overrides global macro in substitution
|
||||||
|
- [ ] Add test: self-reference detection error
|
||||||
|
- [ ] Add test: undefined macro reference error
|
||||||
|
- [ ] Verify all existing macro tests pass: `TestConfig_Macro*`
|
||||||
|
- [ ] Run `make test-all` - ensure all tests including concurrency tests pass
|
||||||
|
|
||||||
|
### Phase 5: Documentation
|
||||||
|
- [ ] Update plan status in this file (mark completed)
|
||||||
|
- [ ] Update CLAUDE.md if macro behavior needs documentation
|
||||||
|
- [ ] Verify no new error messages need user documentation
|
||||||
|
|
||||||
|
## Bug Example (Original Issue)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
macros:
|
||||||
|
"podman-llama": >
|
||||||
|
podman run --name ${MODEL_ID}
|
||||||
|
--init --rm -p ${PORT}:8080 -v /home/alex/ai/models:/models:z --gpus=all
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server-cuda
|
||||||
|
|
||||||
|
"standard-options": >
|
||||||
|
--no-mmap --jinja
|
||||||
|
|
||||||
|
"kv8": >
|
||||||
|
-fa on -ctk q8_0 -ctv q8_0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Current Bug:**
|
||||||
|
- During macro substitution, if `${MODEL_ID}` is processed before `${podman-llama}`, the `${MODEL_ID}` reference inside `podman-llama` remains unsubstituted
|
||||||
|
- Results in error: `unknown macro '${MODEL_ID}' found in model.cmd`
|
||||||
|
|
||||||
|
**After Fix:**
|
||||||
|
- Macros substituted in LIFO order: `kv8` → `standard-options` → `podman-llama`
|
||||||
|
- `MODEL_ID` is a reserved macro, substituted last (after all user macros)
|
||||||
|
- `${MODEL_ID}` inside `podman-llama` is correctly replaced with the model name
|
||||||
@@ -1,88 +1,295 @@
|
|||||||
# Seconds to wait for llama.cpp to be available to serve requests
|
# llama-swap YAML configuration example
|
||||||
# Default (and minimum): 15 seconds
|
# -------------------------------------
|
||||||
healthCheckTimeout: 90
|
#
|
||||||
|
# 💡 Tip - Use an LLM with this file!
|
||||||
|
# ====================================
|
||||||
|
# This example configuration is written to be LLM friendly. Try
|
||||||
|
# copying this file into an LLM and asking it to explain or generate
|
||||||
|
# sections for you.
|
||||||
|
# ====================================
|
||||||
|
|
||||||
# valid log levels: debug, info (default), warn, error
|
# Usage notes:
|
||||||
logLevel: debug
|
# - Below are all the available configuration options for llama-swap.
|
||||||
|
# - Settings noted as "required" must be in your configuration file
|
||||||
|
# - Settings noted as "optional" can be omitted
|
||||||
|
|
||||||
# creating a coding profile with models for code generation and general questions
|
# healthCheckTimeout: number of seconds to wait for a model to be ready to serve requests
|
||||||
groups:
|
# - optional, default: 120
|
||||||
coding:
|
# - minimum value is 15 seconds, anything less will be set to this value
|
||||||
swap: false
|
healthCheckTimeout: 500
|
||||||
members:
|
|
||||||
- "qwen"
|
|
||||||
- "llama"
|
|
||||||
|
|
||||||
|
# logLevel: sets the logging value
|
||||||
|
# - optional, default: info
|
||||||
|
# - Valid log levels: debug, info, warn, error
|
||||||
|
logLevel: info
|
||||||
|
|
||||||
|
# metricsMaxInMemory: maximum number of metrics to keep in memory
|
||||||
|
# - optional, default: 1000
|
||||||
|
# - controls how many metrics are stored in memory before older ones are discarded
|
||||||
|
# - useful for limiting memory usage when processing large volumes of metrics
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
|
# startPort: sets the starting port number for the automatic ${PORT} macro.
|
||||||
|
# - optional, default: 5800
|
||||||
|
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
|
||||||
|
# - it is automatically incremented for every model that uses it
|
||||||
|
startPort: 10001
|
||||||
|
|
||||||
|
# macros: a dictionary of string substitutions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - macros are reusable snippets
|
||||||
|
# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
|
||||||
|
# - useful for reducing common configuration settings
|
||||||
|
# - macro names are strings and must be less than 64 characters
|
||||||
|
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||||
|
# - macro values can be numbers, bools, or strings
|
||||||
|
# - macros can contain other macros, but they must be defined before they are used
|
||||||
|
macros:
|
||||||
|
# Example of a multi-line macro
|
||||||
|
"latest-llama": >
|
||||||
|
/path/to/llama-server/llama-server-ec9e0301
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
"default_ctx": 4096
|
||||||
|
|
||||||
|
# Example of macro-in-macro usage. macros can contain other macros
|
||||||
|
# but they must be previously declared.
|
||||||
|
"default_args": "--ctx-size ${default_ctx}"
|
||||||
|
|
||||||
|
# models: a dictionary of model configurations
|
||||||
|
# - required
|
||||||
|
# - each key is the model's ID, used in API requests
|
||||||
|
# - model settings have default values that are used if they are not defined here
|
||||||
|
# - the model's ID is available in the ${MODEL_ID} macro, also available in macros defined above
|
||||||
|
# - below are examples of the all the settings a model can have
|
||||||
models:
|
models:
|
||||||
|
|
||||||
|
# keys are the model names used in API requests
|
||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
# macros: a dictionary of string substitutions specific to this model
|
||||||
models/llama-server-osx
|
# - optional, default: empty dictionary
|
||||||
--port ${PORT}
|
# - macros defined here override macros defined in the global macros section
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
# - model level macros follow the same rules as global macros
|
||||||
|
macros:
|
||||||
|
"default_ctx": 16384
|
||||||
|
"temp": 0.7
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# cmd: the command to run to start the inference server.
|
||||||
|
# - required
|
||||||
|
# - it is just a string, similar to what you would run on the CLI
|
||||||
|
# - using `|` allows for comments in the command, these will be parsed out
|
||||||
|
# - macros can be used within cmd
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model path/to/llama-8B-Q4_K_M.gguf
|
||||||
|
--ctx-size ${default_ctx}
|
||||||
|
--temperature ${temp}
|
||||||
|
|
||||||
|
# name: a display name for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
name: "llama 3.1 8B"
|
||||||
|
|
||||||
|
# description: a description for the model
|
||||||
|
# - optional, default: empty string
|
||||||
|
# - if set, it will be used in the v1/models API response
|
||||||
|
# - if not set, it will be omitted in the JSON model record
|
||||||
|
description: "A small but capable model used for quick testing"
|
||||||
|
|
||||||
|
# env: define an array of environment variables to inject into cmd's environment
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - each value is a single string
|
||||||
|
# - in the format: ENV_NAME=value
|
||||||
|
env:
|
||||||
|
- "CUDA_VISIBLE_DEVICES=0,1,2"
|
||||||
|
|
||||||
|
# proxy: the URL where llama-swap routes API requests
|
||||||
|
# - optional, default: http://localhost:${PORT}
|
||||||
|
# - if you used ${PORT} in cmd this can be omitted
|
||||||
|
# - if you use a custom port in cmd this *must* be set
|
||||||
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# aliases: alternative model names that this model configuration is used for
|
||||||
|
# - optional, default: empty array
|
||||||
|
# - aliases must be unique globally
|
||||||
|
# - useful for impersonating a specific model
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-4o-mini
|
- "gpt-4o-mini"
|
||||||
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# check this path for a HTTP 200 response for the server to be ready
|
# checkEndpoint: URL path to check if the server is ready
|
||||||
checkEndpoint: /health
|
# - optional, default: /health
|
||||||
|
# - endpoint is expected to return an HTTP 200 response
|
||||||
|
# - all requests wait until the endpoint is ready or fails
|
||||||
|
# - use "none" to skip endpoint health checking
|
||||||
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# unload model after 5 seconds
|
# ttl: automatically unload the model after ttl seconds
|
||||||
ttl: 5
|
# - optional, default: 0
|
||||||
|
# - ttl values must be a value greater than 0
|
||||||
|
# - a value of 0 disables automatic unloading of the model
|
||||||
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
# useModelName: override the model name that is sent to upstream server
|
||||||
cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
# - optional, default: ""
|
||||||
aliases:
|
# - useful for when the upstream server expects a specific model name that
|
||||||
- gpt-3.5-turbo
|
# is different from the model's ID
|
||||||
|
useModelName: "qwen:qwq"
|
||||||
|
|
||||||
# Embedding example with Nomic
|
# filters: a dictionary of filter settings
|
||||||
# https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF
|
# - optional, default: empty dictionary
|
||||||
"nomic":
|
# - only stripParams is currently supported
|
||||||
cmd: >
|
filters:
|
||||||
models/llama-server-osx --port ${PORT}
|
# stripParams: a comma separated list of parameters to remove from the request
|
||||||
-m models/nomic-embed-text-v1.5.Q8_0.gguf
|
# - optional, default: ""
|
||||||
--ctx-size 8192
|
# - useful for server side enforcement of sampling parameters
|
||||||
--batch-size 8192
|
# - the `model` parameter can never be removed
|
||||||
--rope-scaling yarn
|
# - can be any JSON key in the request body
|
||||||
--rope-freq-scale 0.75
|
# - recommended to stick to sampling parameters
|
||||||
-ngl 99
|
stripParams: "temperature, top_p, top_k"
|
||||||
--embeddings
|
|
||||||
|
|
||||||
# Reranking example with bge-reranker
|
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||||
# https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF
|
# - optional, default: empty dictionary
|
||||||
"bge-reranker":
|
# - while metadata can contains complex types it is recommended to keep it simple
|
||||||
cmd: >
|
# - metadata is only passed through in /v1/models responses
|
||||||
models/llama-server-osx --port ${PORT}
|
metadata:
|
||||||
-m models/bge-reranker-v2-m3-Q4_K_M.gguf
|
# port will remain an integer
|
||||||
--ctx-size 8192
|
port: ${PORT}
|
||||||
--reranking
|
|
||||||
|
|
||||||
# Docker Support (v26.1.4+ required!)
|
# the ${temp} macro will remain a float
|
||||||
"dockertest":
|
temperature: ${temp}
|
||||||
cmd: >
|
note: "The ${MODEL_ID} is running on port ${PORT} temp=${temp}, context=${default_ctx}"
|
||||||
docker run --name dockertest
|
|
||||||
|
a_list:
|
||||||
|
- 1
|
||||||
|
- 1.23
|
||||||
|
- "macros are OK in list and dictionary types: ${MODEL_ID}"
|
||||||
|
|
||||||
|
an_obj:
|
||||||
|
a: "1"
|
||||||
|
b: 2
|
||||||
|
# objects can contain complex types with macro substitution
|
||||||
|
# becomes: c: [0.7, false, "model: llama"]
|
||||||
|
c: ["${temp}", false, "model: ${MODEL_ID}"]
|
||||||
|
|
||||||
|
# concurrencyLimit: overrides the allowed number of active parallel requests to a model
|
||||||
|
# - optional, default: 0
|
||||||
|
# - useful for limiting the number of active parallel requests a model can process
|
||||||
|
# - must be set per model
|
||||||
|
# - any number greater than 0 will override the internal default value of 10
|
||||||
|
# - any requests that exceeds the limit will receive an HTTP 429 Too Many Requests response
|
||||||
|
# - recommended to be omitted and the default used
|
||||||
|
concurrencyLimit: 0
|
||||||
|
|
||||||
|
# Unlisted model example:
|
||||||
|
"qwen-unlisted":
|
||||||
|
# unlisted: boolean, true or false
|
||||||
|
# - optional, default: false
|
||||||
|
# - unlisted models do not show up in /v1/models api requests
|
||||||
|
# - can be requested as normal through all apis
|
||||||
|
unlisted: true
|
||||||
|
cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0
|
||||||
|
|
||||||
|
# Docker example:
|
||||||
|
# container runtimes like Docker and Podman can be used reliably with
|
||||||
|
# a combination of cmd, cmdStop, and ${MODEL_ID}
|
||||||
|
"docker-llama":
|
||||||
|
proxy: "http://127.0.0.1:${PORT}"
|
||||||
|
cmd: |
|
||||||
|
docker run --name ${MODEL_ID}
|
||||||
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
--init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models
|
||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
"simple":
|
# cmdStop: command to run to stop the model gracefully
|
||||||
# example of setting environment variables
|
# - optional, default: ""
|
||||||
env:
|
# - useful for stopping commands managed by another system
|
||||||
- CUDA_VISIBLE_DEVICES=0,1
|
# - the upstream's process id is available in the ${PID} macro
|
||||||
- env1=hello
|
#
|
||||||
cmd: build/simple-responder --port ${PORT}
|
# When empty, llama-swap has this default behaviour:
|
||||||
unlisted: true
|
# - on POSIX systems: a SIGTERM signal is sent
|
||||||
|
# - on Windows, calls taskkill to stop the process
|
||||||
|
# - processes have 5 seconds to shutdown until forceful termination is attempted
|
||||||
|
cmdStop: docker stop ${MODEL_ID}
|
||||||
|
|
||||||
# use "none" to skip check. Caution this may cause some requests to fail
|
# groups: a dictionary of group settings
|
||||||
# until the upstream server is ready for traffic
|
# - optional, default: empty dictionary
|
||||||
checkEndpoint: none
|
# - provides advanced controls over model swapping behaviour
|
||||||
|
# - using groups some models can be kept loaded indefinitely, while others are swapped out
|
||||||
|
# - model IDs must be defined in the Models section
|
||||||
|
# - a model can only be a member of one group
|
||||||
|
# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields
|
||||||
|
# - see issue #109 for details
|
||||||
|
#
|
||||||
|
# NOTE: the example below uses model names that are not defined above for demonstration purposes
|
||||||
|
groups:
|
||||||
|
# group1 works the same as the default behaviour of llama-swap where only one model is allowed
|
||||||
|
# to run a time across the whole llama-swap instance
|
||||||
|
"group1":
|
||||||
|
# swap: controls the model swapping behaviour in within the group
|
||||||
|
# - optional, default: true
|
||||||
|
# - true : only one model is allowed to run at a time
|
||||||
|
# - false: all models can run together, no swapping
|
||||||
|
swap: true
|
||||||
|
|
||||||
# don't use these, just for testing if things are broken
|
# exclusive: controls how the group affects other groups
|
||||||
"broken":
|
# - optional, default: true
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf
|
# - true: causes all other groups to unload when this group runs a model
|
||||||
proxy: http://127.0.0.1:8999
|
# - false: does not affect other groups
|
||||||
unlisted: true
|
exclusive: true
|
||||||
"broken_timeout":
|
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
# members references the models defined above
|
||||||
proxy: http://127.0.0.1:9000
|
# required
|
||||||
unlisted: true
|
members:
|
||||||
|
- "llama"
|
||||||
|
- "qwen-unlisted"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - in group2 all models can run at the same time
|
||||||
|
# - when a different group is loaded it causes all running models in this group to unload
|
||||||
|
"group2":
|
||||||
|
swap: false
|
||||||
|
|
||||||
|
# exclusive: false does not unload other groups when a model in group2 is requested
|
||||||
|
# - the models in group2 will be loaded but will not unload any other groups
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "docker-llama"
|
||||||
|
- "modelA"
|
||||||
|
- "modelB"
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# - a persistent group, prevents other groups from unloading it
|
||||||
|
"forever":
|
||||||
|
# persistent: prevents over groups from unloading the models in this group
|
||||||
|
# - optional, default: false
|
||||||
|
# - does not affect individual model behaviour
|
||||||
|
persistent: true
|
||||||
|
|
||||||
|
# set swap/exclusive to false to prevent swapping inside the group
|
||||||
|
# and the unloading of other groups
|
||||||
|
swap: false
|
||||||
|
exclusive: false
|
||||||
|
members:
|
||||||
|
- "forever-modelA"
|
||||||
|
- "forever-modelB"
|
||||||
|
- "forever-modelc"
|
||||||
|
|
||||||
|
# hooks: a dictionary of event triggers and actions
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported hook is on_startup
|
||||||
|
hooks:
|
||||||
|
# on_startup: a dictionary of actions to perform on startup
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - the only supported action is preload
|
||||||
|
on_startup:
|
||||||
|
# preload: a list of model ids to load on startup
|
||||||
|
# - optional, default: empty list
|
||||||
|
# - model names must match keys in the models sections
|
||||||
|
# - when preloading multiple models at once, define a group
|
||||||
|
# otherwise models will be loaded and swapped out
|
||||||
|
preload:
|
||||||
|
- "llama"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
healthCheckTimeout: 300
|
healthCheckTimeout: 300
|
||||||
logRequests: true
|
logRequests: true
|
||||||
|
metricsMaxInMemory: 1000
|
||||||
|
|
||||||
models:
|
models:
|
||||||
"qwen2.5":
|
"qwen2.5":
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
|
||||||
|
|
||||||
|
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default initializes a default in-process dispatcher
|
||||||
|
var Default = NewDispatcherConfig(25000)
|
||||||
|
|
||||||
|
// On subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work. This
|
||||||
|
// functions same way as Subscribe() but uses the default dispatcher instead.
|
||||||
|
func On[T Event](handler func(T)) context.CancelFunc {
|
||||||
|
return Subscribe(Default, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnType subscribes to an event with the specified event type. This functions
|
||||||
|
// same way as SubscribeTo() but uses the default dispatcher instead.
|
||||||
|
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
return SubscribeTo(Default, eventType, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit writes an event into the dispatcher. This functions same way as
|
||||||
|
// Publish() but uses the default dispatcher instead.
|
||||||
|
func Emit[T Event](ev T) {
|
||||||
|
Publish(Default, ev)
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
|
||||||
|
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
|
||||||
|
*/
|
||||||
|
func BenchmarkSubscribeConcurrent(b *testing.B) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
unsub := Subscribe(d, func(ev MyEvent1) {})
|
||||||
|
unsub()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultPublish(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
var count int64
|
||||||
|
defer On(func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
defer OnType(TypeEvent1, func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&count, 1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(4)
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
Emit(MyEvent1{})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(4), count)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for details.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event represents an event contract
|
||||||
|
type Event interface {
|
||||||
|
Type() uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// registry holds an immutable sorted array of event mappings
|
||||||
|
type registry struct {
|
||||||
|
keys []uint32 // Event types (sorted)
|
||||||
|
grps []any // Corresponding subscribers
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Dispatcher -------------------------------------
|
||||||
|
|
||||||
|
// Dispatcher represents an event dispatcher.
|
||||||
|
type Dispatcher struct {
|
||||||
|
subs atomic.Pointer[registry] // Atomic pointer to immutable array
|
||||||
|
done chan struct{} // Cancellation
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcher creates a new dispatcher of events.
|
||||||
|
func NewDispatcher() *Dispatcher {
|
||||||
|
return NewDispatcherConfig(50000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
|
||||||
|
func NewDispatcherConfig(maxQueue int) *Dispatcher {
|
||||||
|
d := &Dispatcher{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
maxQueue: maxQueue,
|
||||||
|
}
|
||||||
|
|
||||||
|
d.subs.Store(®istry{
|
||||||
|
keys: make([]uint32, 0, 16),
|
||||||
|
grps: make([]any, 0, 16),
|
||||||
|
})
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the dispatcher
|
||||||
|
func (d *Dispatcher) Close() error {
|
||||||
|
close(d.done)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClosed returns whether the dispatcher is closed or not
|
||||||
|
func (d *Dispatcher) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-d.done:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findGroup performs a lock-free binary search for the event type
|
||||||
|
func (d *Dispatcher) findGroup(eventType uint32) any {
|
||||||
|
reg := d.subs.Load()
|
||||||
|
keys := reg.keys
|
||||||
|
|
||||||
|
// Inlined binary search for better cache locality
|
||||||
|
left, right := 0, len(keys)
|
||||||
|
for left < right {
|
||||||
|
mid := left + (right-left)/2
|
||||||
|
if keys[mid] < eventType {
|
||||||
|
left = mid + 1
|
||||||
|
} else {
|
||||||
|
right = mid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if left < len(keys) && keys[left] == eventType {
|
||||||
|
return reg.grps[left]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to an event, the type of the event will be automatically
|
||||||
|
// inferred from the provided type. Must be constant for this to work.
|
||||||
|
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
|
||||||
|
var event T
|
||||||
|
return SubscribeTo(broker, event.Type(), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeTo subscribes to an event with the specified event type.
|
||||||
|
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
|
||||||
|
if broker.isClosed() {
|
||||||
|
panic(errClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
broker.mu.Lock()
|
||||||
|
defer broker.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if group already exists
|
||||||
|
if existing := broker.findGroup(eventType); existing != nil {
|
||||||
|
grp := groupOf[T](eventType, existing)
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new group
|
||||||
|
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
|
||||||
|
sub := grp.Add(handler)
|
||||||
|
|
||||||
|
// Copy-on-write: insert new entry in sorted position
|
||||||
|
old := broker.subs.Load()
|
||||||
|
idx := sort.Search(len(old.keys), func(i int) bool {
|
||||||
|
return old.keys[i] >= eventType
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create new arrays with space for one more element
|
||||||
|
newKeys := make([]uint32, len(old.keys)+1)
|
||||||
|
newGrps := make([]any, len(old.grps)+1)
|
||||||
|
|
||||||
|
// Copy elements before insertion point
|
||||||
|
copy(newKeys[:idx], old.keys[:idx])
|
||||||
|
copy(newGrps[:idx], old.grps[:idx])
|
||||||
|
|
||||||
|
// Insert new element
|
||||||
|
newKeys[idx] = eventType
|
||||||
|
newGrps[idx] = grp
|
||||||
|
|
||||||
|
// Copy elements after insertion point
|
||||||
|
copy(newKeys[idx+1:], old.keys[idx:])
|
||||||
|
copy(newGrps[idx+1:], old.grps[idx:])
|
||||||
|
|
||||||
|
// Atomically store the new registry (mutex ensures no concurrent writers)
|
||||||
|
newReg := ®istry{keys: newKeys, grps: newGrps}
|
||||||
|
broker.subs.Store(newReg)
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
grp.Del(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish writes an event into the dispatcher
|
||||||
|
func Publish[T Event](broker *Dispatcher, ev T) {
|
||||||
|
eventType := ev.Type()
|
||||||
|
if sub := broker.findGroup(eventType); sub != nil {
|
||||||
|
group := groupOf[T](eventType, sub)
|
||||||
|
group.Broadcast(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count counts the number of subscribers, this is for testing only.
|
||||||
|
func (d *Dispatcher) count(eventType uint32) int {
|
||||||
|
if group := d.findGroup(eventType); group != nil {
|
||||||
|
return group.(interface{ Count() int }).Count()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupOf casts the subscriber group to the specified generic type
|
||||||
|
func groupOf[T Event](eventType uint32, subs any) *group[T] {
|
||||||
|
if group, ok := subs.(*group[T]); ok {
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
panic(errConflict[T](eventType, subs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber -------------------------------------
|
||||||
|
|
||||||
|
// consumer represents a consumer with a message queue
|
||||||
|
type consumer[T Event] struct {
|
||||||
|
queue []T // Current work queue
|
||||||
|
stop bool // Stop signal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen listens to the event queue and processes events
|
||||||
|
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
|
||||||
|
pending := make([]T, 0, 128)
|
||||||
|
|
||||||
|
for {
|
||||||
|
c.L.Lock()
|
||||||
|
for len(s.queue) == 0 {
|
||||||
|
switch {
|
||||||
|
case s.stop:
|
||||||
|
c.L.Unlock()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap buffers and reset the current queue
|
||||||
|
temp := s.queue
|
||||||
|
s.queue = pending[:0]
|
||||||
|
pending = temp
|
||||||
|
c.L.Unlock()
|
||||||
|
|
||||||
|
// Outside of the critical section, process the work
|
||||||
|
for _, event := range pending {
|
||||||
|
fn(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify potential publishers waiting due to backpressure
|
||||||
|
c.Broadcast()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Subscriber Group -------------------------------------
|
||||||
|
|
||||||
|
// group represents a consumer group
|
||||||
|
type group[T Event] struct {
|
||||||
|
cond *sync.Cond
|
||||||
|
subs []*consumer[T]
|
||||||
|
maxQueue int // Maximum queue size per consumer
|
||||||
|
maxLen int // Current maximum queue length across all consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends an event to all consumers
|
||||||
|
func (s *group[T]) Broadcast(ev T) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Calculate current maximum queue length
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backpressure: wait if queues are full
|
||||||
|
for s.maxLen >= s.maxQueue {
|
||||||
|
s.cond.Wait()
|
||||||
|
|
||||||
|
// Recalculate after wakeup
|
||||||
|
s.maxLen = 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
if len(sub.queue) > s.maxLen {
|
||||||
|
s.maxLen = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add event to all queues and track new maximum
|
||||||
|
newMax := 0
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
sub.queue = append(sub.queue, ev)
|
||||||
|
if len(sub.queue) > newMax {
|
||||||
|
newMax = len(sub.queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.maxLen = newMax
|
||||||
|
s.cond.Broadcast() // Wake consumers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a subscriber to the list
|
||||||
|
func (s *group[T]) Add(handler func(T)) *consumer[T] {
|
||||||
|
sub := &consumer[T]{
|
||||||
|
queue: make([]T, 0, 64),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the consumer to the list of active consumers
|
||||||
|
s.cond.L.Lock()
|
||||||
|
s.subs = append(s.subs, sub)
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Start listening
|
||||||
|
go sub.Listen(s.cond, handler)
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del removes a subscriber from the list
|
||||||
|
func (s *group[T]) Del(sub *consumer[T]) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
|
||||||
|
// Search and remove the subscriber
|
||||||
|
sub.stop = true
|
||||||
|
for i, v := range s.subs {
|
||||||
|
if v == sub {
|
||||||
|
copy(s.subs[i:], s.subs[i+1:])
|
||||||
|
s.subs = s.subs[:len(s.subs)-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Debugging -------------------------------------
|
||||||
|
|
||||||
|
var errClosed = fmt.Errorf("event dispatcher is closed")
|
||||||
|
|
||||||
|
// Count returns the number of subscribers in this group
|
||||||
|
func (s *group[T]) Count() int {
|
||||||
|
return len(s.subs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation of the type
|
||||||
|
func (s *group[T]) String() string {
|
||||||
|
typ := reflect.TypeOf(s).String()
|
||||||
|
idx := strings.LastIndex(typ, "/")
|
||||||
|
typ = typ[idx+1 : len(typ)-1]
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
// errConflict returns a conflict message
|
||||||
|
func errConflict[T any](eventType uint32, existing any) string {
|
||||||
|
var want T
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
|
||||||
|
want, existing, eventType,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
|
||||||
|
|
||||||
|
package event
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublish(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Subscribe, must be received in order
|
||||||
|
var count int64
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Publish
|
||||||
|
wg.Add(3)
|
||||||
|
Publish(d, MyEvent1{Number: 1})
|
||||||
|
Publish(d, MyEvent1{Number: 2})
|
||||||
|
Publish(d, MyEvent1{Number: 3})
|
||||||
|
|
||||||
|
// Wait and check
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, int64(3), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsubscribe(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
unsubscribe := Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Nothing
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, 1, d.count(TypeEvent1))
|
||||||
|
unsubscribe()
|
||||||
|
assert.Equal(t, 0, d.count(TypeEvent1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrent(t *testing.T) {
|
||||||
|
const max = 1000000
|
||||||
|
var count int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
if current := atomic.AddInt64(&count, 1); current == max {
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
// Asynchronously publish
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < max; i++ {
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer Subscribe(d, func(ev MyEvent1) {
|
||||||
|
// Subscriber that does nothing
|
||||||
|
})()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, max, int(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscribeDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishDifferentType(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseDispatcher(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
|
||||||
|
|
||||||
|
assert.NoError(t, d.Close())
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix(t *testing.T) {
|
||||||
|
const amount = 1000
|
||||||
|
for _, subs := range []int{1, 10, 100} {
|
||||||
|
for _, topics := range []int{1, 10} {
|
||||||
|
expected := subs * topics * amount
|
||||||
|
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
|
||||||
|
var count atomic.Int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(expected)
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
for i := 0; i < subs; i++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
|
||||||
|
count.Add(1)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < amount; n++ {
|
||||||
|
for id := 0; id < topics; id++ {
|
||||||
|
go Publish(d, MyEvent3{ID: id})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, expected, int(count.Load()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentSubscriptionRace(t *testing.T) {
|
||||||
|
// This test specifically targets the race condition that occurs when multiple
|
||||||
|
// goroutines try to subscribe to different event types simultaneously.
|
||||||
|
// Without the CAS loop, subscriptions could be lost due to registry corruption.
|
||||||
|
|
||||||
|
const numGoroutines = 100
|
||||||
|
const numEventTypes = 50
|
||||||
|
|
||||||
|
d := NewDispatcher()
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var receivedCount int64
|
||||||
|
var subscribedTypes sync.Map // Thread-safe map
|
||||||
|
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
// Start multiple goroutines that subscribe to different event types concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Each goroutine subscribes to a unique event type
|
||||||
|
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
|
||||||
|
|
||||||
|
// Subscribe to the event type
|
||||||
|
SubscribeTo(d, eventType, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&receivedCount, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Record that this type was subscribed
|
||||||
|
subscribedTypes.Store(eventType, true)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all subscriptions to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Count the number of unique event types subscribed
|
||||||
|
expectedTypes := 0
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
expectedTypes++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Small delay to ensure all subscriptions are fully processed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Publish events to each subscribed type
|
||||||
|
subscribedTypes.Range(func(key, value interface{}) bool {
|
||||||
|
eventType := key.(uint32)
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for all events to be processed
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify that we received at least the expected number of events
|
||||||
|
// (there might be more if multiple goroutines subscribed to the same event type)
|
||||||
|
received := atomic.LoadInt64(&receivedCount)
|
||||||
|
assert.GreaterOrEqual(t, int(received), expectedTypes,
|
||||||
|
"Should have received at least %d events, got %d", expectedTypes, received)
|
||||||
|
|
||||||
|
// Verify that we have the expected number of unique event types
|
||||||
|
assert.Equal(t, numEventTypes, expectedTypes,
|
||||||
|
"Should have exactly %d unique event types", numEventTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentHandlerRegistration(t *testing.T) {
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to the same event type
|
||||||
|
t.Run("SameEventType", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var handlerCount int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Start multiple goroutines subscribing to the same event type (0x1)
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
|
||||||
|
atomic.AddInt64(&handlerCount, 1)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify all handlers were registered by publishing an event
|
||||||
|
atomic.StoreInt64(&handlerCount, 0)
|
||||||
|
Publish(d, MyEvent1{})
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
|
||||||
|
"Not all handlers were registered due to race condition")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test concurrent subscriptions to different event types
|
||||||
|
t.Run("DifferentEventTypes", func(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
receivedEvents := make(map[uint32]*int64)
|
||||||
|
|
||||||
|
// Create multiple event types and subscribe concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
eventType := uint32(100 + i)
|
||||||
|
counter := new(int64)
|
||||||
|
receivedEvents[eventType] = counter
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(et uint32, cnt *int64) {
|
||||||
|
defer wg.Done()
|
||||||
|
SubscribeTo(d, et, func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(cnt, 1)
|
||||||
|
})
|
||||||
|
}(eventType, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Publish events to all types
|
||||||
|
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
|
||||||
|
Publish(d, MyEvent3{ID: int(eventType)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small delay to ensure all handlers have executed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all event types received their events
|
||||||
|
for eventType, counter := range receivedEvents {
|
||||||
|
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
|
||||||
|
"Event type %d did not receive its event", eventType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackpressure(t *testing.T) {
|
||||||
|
d := NewDispatcher()
|
||||||
|
d.maxQueue = 10
|
||||||
|
|
||||||
|
var processedCount int64
|
||||||
|
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
|
||||||
|
atomic.AddInt64(&processedCount, 1)
|
||||||
|
})
|
||||||
|
defer unsub()
|
||||||
|
|
||||||
|
const eventsToPublish = 1000
|
||||||
|
for i := 0; i < eventsToPublish; i++ {
|
||||||
|
Publish(d, MyEvent3{ID: 0x200})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify all events were eventually processed
|
||||||
|
finalProcessed := atomic.LoadInt64(&processedCount)
|
||||||
|
assert.Equal(t, int64(eventsToPublish), finalProcessed)
|
||||||
|
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------- Test Events -------------------------------------
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeEvent1 = 0x1
|
||||||
|
TypeEvent2 = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
type MyEvent1 struct {
|
||||||
|
Number int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
|
||||||
|
|
||||||
|
type MyEvent2 struct {
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
|
||||||
|
|
||||||
|
type MyEvent3 struct {
|
||||||
|
ID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
|
||||||
@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
|
|||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/billziss-gh/golib v0.2.0
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
@@ -12,7 +13,6 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/billziss-gh/golib v0.2.0 // indirect
|
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
|
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -37,13 +39,13 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := proxy.LoadConfig(*configPath)
|
conf, err := config.LoadConfig(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error loading config: %v\n", err)
|
fmt.Printf("Error loading config: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Profiles) > 0 {
|
if len(conf.Profiles) > 0 {
|
||||||
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,137 +55,135 @@ func main() {
|
|||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyManager := proxy.New(config)
|
|
||||||
|
|
||||||
// Setup channels for server management
|
// Setup channels for server management
|
||||||
reloadChan := make(chan *proxy.ProxyManager)
|
|
||||||
exitChan := make(chan struct{})
|
exitChan := make(chan struct{})
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Create server with initial handler
|
// Create server with initial handler
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: *listenStr,
|
Addr: *listenStr,
|
||||||
Handler: proxyManager,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Support for watching config and reloading when it changes
|
||||||
|
reloadProxyManager := func() {
|
||||||
|
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
conf, err = config.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Configuration Changed")
|
||||||
|
currentPM.Shutdown()
|
||||||
|
srv.Handler = proxy.New(conf)
|
||||||
|
fmt.Println("Configuration Reloaded")
|
||||||
|
|
||||||
|
// wait a few seconds and tell any UI to reload
|
||||||
|
time.AfterFunc(3*time.Second, func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateEnd,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
conf, err = config.LoadConfig(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error, unable to load configuration: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
srv.Handler = proxy.New(conf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// load the initial proxy manager
|
||||||
|
reloadProxyManager()
|
||||||
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
|
if *watchConfig {
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
fmt.Println("Watching Configuration for changes")
|
||||||
|
go func() {
|
||||||
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Dir(absConfigPath)
|
||||||
|
err = watcher.Add(configDir)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer watcher.Close()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case changeEvent := <-watcher.Events:
|
||||||
|
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
||||||
|
// the change for k8s configmap
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case err := <-watcher.Errors:
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown on signal
|
||||||
|
go func() {
|
||||||
|
sig := <-sigChan
|
||||||
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
|
||||||
|
pm.Shutdown()
|
||||||
|
} else {
|
||||||
|
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Printf("Server shutdown error: %v\n", err)
|
||||||
|
}
|
||||||
|
close(exitChan)
|
||||||
|
}()
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
fmt.Printf("llama-swap listening on %s\n", *listenStr)
|
||||||
go func() {
|
go func() {
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
fmt.Printf("Fatal server error: %v\n", err)
|
log.Fatalf("Fatal server error: %v\n", err)
|
||||||
close(exitChan)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Handle config reloads and signals
|
|
||||||
go func() {
|
|
||||||
currentManager := proxyManager
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case newManager := <-reloadChan:
|
|
||||||
log.Println("Config change detected, waiting for in-flight requests to complete...")
|
|
||||||
// Stop old manager processes gracefully (this waits for in-flight requests)
|
|
||||||
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
|
|
||||||
// Now do a full shutdown to clear the process map
|
|
||||||
currentManager.Shutdown()
|
|
||||||
currentManager = newManager
|
|
||||||
srv.Handler = newManager
|
|
||||||
log.Println("Server handler updated with new config")
|
|
||||||
case sig := <-sigChan:
|
|
||||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
currentManager.Shutdown()
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
|
||||||
fmt.Printf("Server shutdown error: %v\n", err)
|
|
||||||
}
|
|
||||||
close(exitChan)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start file watcher if requested
|
|
||||||
if *watchConfig {
|
|
||||||
absConfigPath, err := filepath.Abs(*configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
|
|
||||||
} else {
|
|
||||||
go watchConfigFileWithReload(absConfigPath, reloadChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for exit signal
|
// Wait for exit signal
|
||||||
<-exitChan
|
<-exitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
|
func debounce(interval time.Duration, f func()) func() {
|
||||||
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
|
var timer *time.Timer
|
||||||
watcher, err := fsnotify.NewWatcher()
|
return func() {
|
||||||
if err != nil {
|
if timer != nil {
|
||||||
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
|
timer.Stop()
|
||||||
return
|
|
||||||
}
|
|
||||||
defer watcher.Close()
|
|
||||||
|
|
||||||
err = watcher.Add(configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Watching config file for changes: %s", configPath)
|
|
||||||
|
|
||||||
var debounceTimer *time.Timer
|
|
||||||
debounceDuration := 2 * time.Second
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case event, ok := <-watcher.Events:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// We only care about writes to the specific config file
|
|
||||||
if event.Name == configPath && event.Has(fsnotify.Write) {
|
|
||||||
// Reset or start the debounce timer
|
|
||||||
if debounceTimer != nil {
|
|
||||||
debounceTimer.Stop()
|
|
||||||
}
|
|
||||||
debounceTimer = time.AfterFunc(debounceDuration, func() {
|
|
||||||
log.Printf("Config file modified: %s, reloading...", event.Name)
|
|
||||||
|
|
||||||
// Try up to 3 times with exponential backoff
|
|
||||||
var newConfig proxy.Config
|
|
||||||
var err error
|
|
||||||
for retries := 0; retries < 3; retries++ {
|
|
||||||
// Load new configuration
|
|
||||||
newConfig, err = proxy.LoadConfig(configPath)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
|
|
||||||
if retries < 2 {
|
|
||||||
time.Sleep(time.Duration(1<<retries) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to load new config after retries: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new ProxyManager with new config
|
|
||||||
newPM := proxy.New(newConfig)
|
|
||||||
reloadChan <- newPM
|
|
||||||
log.Println("Config reloaded successfully")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case err, ok := <-watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
log.Println("File watcher error channel closed.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("File watcher error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
timer = time.AfterFunc(interval, f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// created for issue: #252 https://github.com/mostlygeek/llama-swap/issues/252
|
||||||
|
// this simple benchmark tool sends a lot of small chat completion requests to llama-swap
|
||||||
|
// to make sure all the requests are accounted for.
|
||||||
|
//
|
||||||
|
// requests can be sent in parallel, and the tool will report the results.
|
||||||
|
// usage: go run main.go -baseurl http://localhost:8080/v1 -model llama3 -requests 1000 -par 5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// ----- CLI arguments ----------------------------------------------------
|
||||||
|
var (
|
||||||
|
baseurl string
|
||||||
|
modelName string
|
||||||
|
totalRequests int
|
||||||
|
parallelization int
|
||||||
|
)
|
||||||
|
|
||||||
|
flag.StringVar(&baseurl, "baseurl", "http://localhost:8080/v1", "Base URL of the API (e.g., https://api.example.com)")
|
||||||
|
flag.StringVar(&modelName, "model", "", "Model name to use")
|
||||||
|
flag.IntVar(&totalRequests, "requests", 1, "Total number of requests to send")
|
||||||
|
flag.IntVar(¶llelization, "par", 1, "Maximum number of concurrent requests")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if baseurl == "" || modelName == "" {
|
||||||
|
fmt.Println("Error: both -baseurl and -model are required.")
|
||||||
|
flag.Usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if totalRequests <= 0 {
|
||||||
|
fmt.Println("Error: -requests must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if parallelization <= 0 {
|
||||||
|
fmt.Println("Error: -parallelization must be greater than 0.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- HTTP client -------------------------------------------------------
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Tracking response codes -------------------------------------------
|
||||||
|
statusCounts := make(map[int]int) // map[statusCode]count
|
||||||
|
var mu sync.Mutex // protects statusCounts
|
||||||
|
|
||||||
|
// ----- Request queue (buffered channel) ----------------------------------
|
||||||
|
requests := make(chan int, 10) // Buffered channel with capacity 10
|
||||||
|
|
||||||
|
// Goroutine to fill the request queue
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < totalRequests; i++ {
|
||||||
|
requests <- i + 1
|
||||||
|
}
|
||||||
|
close(requests)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ----- Worker pool -------------------------------------------------------
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < parallelization; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(workerID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for reqID := range requests {
|
||||||
|
// Build request payload as a single line JSON string
|
||||||
|
payload := `{"model":"` + modelName + `","max_tokens":100,"stream":false,"messages":[{"role":"user","content":"write a snake game in python"}]}`
|
||||||
|
|
||||||
|
// Send POST request
|
||||||
|
req, err := http.NewRequest(http.MethodPost,
|
||||||
|
fmt.Sprintf("%s/chat/completions", baseurl),
|
||||||
|
bytes.NewReader([]byte(payload)))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] request creation error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[worker %d][req %d] HTTP request error: %v", workerID, reqID, err)
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[-1]++
|
||||||
|
mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// Record status code
|
||||||
|
mu.Lock()
|
||||||
|
statusCounts[resp.StatusCode]++
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Status ticker (prints every second) -------------------------------
|
||||||
|
done := make(chan struct{})
|
||||||
|
tickerDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
startTime := time.Now()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
mu.Lock()
|
||||||
|
// Compute how many requests have completed so far
|
||||||
|
completed := 0
|
||||||
|
for _, cnt := range statusCounts {
|
||||||
|
completed += cnt
|
||||||
|
}
|
||||||
|
// Calculate duration and progress
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
progress := completed * 100 / totalRequests
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, progress)
|
||||||
|
mu.Unlock()
|
||||||
|
case <-done:
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
fmt.Printf("Duration: %v, Completed: %d%% requests\n", duration, 100)
|
||||||
|
close(tickerDone)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for all workers to finish
|
||||||
|
wg.Wait()
|
||||||
|
close(done) // stops the status-update goroutine
|
||||||
|
<-tickerDone // give ticker time to finish / print
|
||||||
|
|
||||||
|
// ----- Summary ------------------------------------------------------------
|
||||||
|
fmt.Println("\n\n=== HTTP response code summary ===")
|
||||||
|
mu.Lock()
|
||||||
|
for code, cnt := range statusCounts {
|
||||||
|
if code == -1 {
|
||||||
|
fmt.Printf("Client-side errors (no HTTP response): %d\n", cnt)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%d : %d\n", code, cnt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
**
|
||||||
|
Test how exec.Cmd.CommandContext behaves under certain conditions:*
|
||||||
|
|
||||||
|
- process is killed externally, what happens with cmd.Wait() *
|
||||||
|
✔︎ it returns. catches crashes.*
|
||||||
|
|
||||||
|
- process ignores SIGTERM*
|
||||||
|
✔︎ `kill()` is called after cmd.WaitDelay*
|
||||||
|
|
||||||
|
- this process exits, what happens with children (kill -9 <this process' pid>)*
|
||||||
|
x they stick around. have to be manually killed.*
|
||||||
|
|
||||||
|
- .WithTimeout()'s cancel is called *
|
||||||
|
✔︎ process is killed after it ignores sigterm, cmd.Wait() catches it.*
|
||||||
|
|
||||||
|
- parent receives SIGINT/SIGTERM, what happens
|
||||||
|
✔︎ waits for child process to exit, then exits gracefully.
|
||||||
|
*/
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
// swap between these to use kill -9 <pid> on the cli to sim external crash
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
//ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
//cmd := exec.CommandContext(ctx, "sleep", "1")
|
||||||
|
cmd := exec.CommandContext(ctx,
|
||||||
|
"../../build/simple-responder_darwin_arm64",
|
||||||
|
//"-ignore-sig-term", /* so it doesn't exit on receiving SIGTERM, test cmd.WaitTimeout */
|
||||||
|
)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// set a wait delay before signing sig kill
|
||||||
|
cmd.WaitDelay = 500 * time.Millisecond
|
||||||
|
cmd.Cancel = func() error {
|
||||||
|
fmt.Println("✔︎ Cancel() called, sending SIGTERM")
|
||||||
|
cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
|
||||||
|
//return nil
|
||||||
|
|
||||||
|
// this error is returned by cmd.Wait(), and can be used to
|
||||||
|
// single an error when the process couldn't be normally terminated
|
||||||
|
// but since a SIGTERM is sent, it's probably ok to return a nil
|
||||||
|
// as WaitDelay timing out will override the any error set here.
|
||||||
|
//
|
||||||
|
// test by enabling/disabling -ignore-sig-term on the process
|
||||||
|
// with -ignore-sig-term enabled, cmd.Wait() will have "signal: killed"
|
||||||
|
// without it, it will show the "new error from cancel"
|
||||||
|
return errors.New("error from cmd.Cancel()") // sets error returned by cmd.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
fmt.Println("Error starting process:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// catch signals. Calls cancel() which will cause cmd.Wait() to return and
|
||||||
|
// this program to eventually exit gracefully.
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
signal := <-sigChan
|
||||||
|
fmt.Printf("✔︎ Received signal: %d, Killing process... with cancel before exiting\n", signal)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fmt.Printf("✔︎ Parent Pid: %d, Process Pid: %d\n", os.Getpid(), cmd.Process.Pid)
|
||||||
|
fmt.Println("✔︎ Process started, cmd.Wait() ... ")
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Error:", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("✔︎ cmd.Wait returned, Process exited on its own")
|
||||||
|
}
|
||||||
|
fmt.Println("✔︎ Child process exited, Done.")
|
||||||
|
}
|
||||||
@@ -35,17 +35,90 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "application/json")
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// Check if streaming is requested
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
// Query is checked instead of JSON body since that event stream conflicts with other tests
|
||||||
time.Sleep(wait)
|
isStreaming := c.Query("stream") == "true"
|
||||||
|
|
||||||
|
if isStreaming {
|
||||||
|
// Set headers for streaming
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 10 "asdf" tokens
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data := gin.H{
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"choices": []gin.H{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": gin.H{
|
||||||
|
"content": "asdf",
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", data)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final data with usage info
|
||||||
|
finalData := gin.H{
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
// add timings to simulate llama.cpp
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.SSEvent("message", finalData)
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Send [DONE]
|
||||||
|
c.SSEvent("message", "[DONE]")
|
||||||
|
c.Writer.Flush()
|
||||||
|
} else {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// add a wait to simulate a slow query
|
||||||
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
|
time.Sleep(wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
"request_body": string(bodyBytes),
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": gin.H{
|
||||||
|
"prompt_n": 25,
|
||||||
|
"prompt_ms": 13,
|
||||||
|
"predicted_n": 10,
|
||||||
|
"predicted_ms": 17,
|
||||||
|
"predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"responseMessage": *responseMessage,
|
|
||||||
"h_content_length": c.Request.Header.Get("Content-Length"),
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -71,10 +144,28 @@ func main() {
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"responseMessage": *responseMessage,
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// llama-server compatibility: /completion
|
||||||
|
r.POST("/completion", func(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"usage": gin.H{
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 25,
|
||||||
|
"total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||||
// Parse the multipart form
|
// Parse the multipart form
|
||||||
@@ -223,13 +314,13 @@ runloop:
|
|||||||
if countSigInt > 1 {
|
if countSigInt > 1 {
|
||||||
break runloop
|
break runloop
|
||||||
} else {
|
} else {
|
||||||
log.Println("Recieved SIGINT, send another SIGINT to shutdown")
|
log.Println("Received SIGINT, send another SIGINT to shutdown")
|
||||||
}
|
}
|
||||||
case syscall.SIGTERM:
|
case syscall.SIGTERM:
|
||||||
if *ignoreSigTerm {
|
if *ignoreSigTerm {
|
||||||
log.Println("Ignoring SIGTERM")
|
log.Println("Ignoring SIGTERM")
|
||||||
} else {
|
} else {
|
||||||
log.Println("Recieved SIGTERM, shutting down")
|
log.Println("Received SIGTERM, shutting down")
|
||||||
break runloop
|
break runloop
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
ui_dist/*
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DEFAULT_GROUP_ID = "(default)"
|
|
||||||
|
|
||||||
type ModelConfig struct {
|
|
||||||
Cmd string `yaml:"cmd"`
|
|
||||||
Proxy string `yaml:"proxy"`
|
|
||||||
Aliases []string `yaml:"aliases"`
|
|
||||||
Env []string `yaml:"env"`
|
|
||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
|
||||||
UnloadAfter int `yaml:"ttl"`
|
|
||||||
Unlisted bool `yaml:"unlisted"`
|
|
||||||
UseModelName string `yaml:"useModelName"`
|
|
||||||
|
|
||||||
// Limit concurrency of HTTP requests to process
|
|
||||||
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|
||||||
return SanitizeCommand(m.Cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupConfig struct {
|
|
||||||
Swap bool `yaml:"swap"`
|
|
||||||
Exclusive bool `yaml:"exclusive"`
|
|
||||||
Persistent bool `yaml:"persistent"`
|
|
||||||
Members []string `yaml:"members"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
||||||
type rawGroupConfig GroupConfig
|
|
||||||
defaults := rawGroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Persistent: false,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unmarshal(&defaults); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*c = GroupConfig(defaults)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
|
||||||
LogRequests bool `yaml:"logRequests"`
|
|
||||||
LogLevel string `yaml:"logLevel"`
|
|
||||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
|
||||||
Profiles map[string][]string `yaml:"profiles"`
|
|
||||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
|
||||||
|
|
||||||
// map aliases to actual model IDs
|
|
||||||
aliases map[string]string
|
|
||||||
|
|
||||||
// automatic port assignments
|
|
||||||
StartPort int `yaml:"startPort"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) RealModelName(search string) (string, bool) {
|
|
||||||
if _, found := c.Models[search]; found {
|
|
||||||
return search, true
|
|
||||||
} else if name, found := c.aliases[search]; found {
|
|
||||||
return name, found
|
|
||||||
} else {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
|
||||||
if realName, found := c.RealModelName(modelName); !found {
|
|
||||||
return ModelConfig{}, "", false
|
|
||||||
} else {
|
|
||||||
return c.Models[realName], realName, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfig(path string) (Config, error) {
|
|
||||||
file, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
return LoadConfigFromReader(file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var config Config
|
|
||||||
err = yaml.Unmarshal(data, &config)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default port ranges
|
|
||||||
if config.StartPort == 0 {
|
|
||||||
// default to 5800
|
|
||||||
config.StartPort = 5800
|
|
||||||
} else if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate over the models and replace any ${PORT} with the next available port
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
|
||||||
|
|
||||||
// iterate over the sorted models
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
if modelConfig.Proxy == "" {
|
|
||||||
modelConfig.Proxy = fmt.Sprintf("http://localhost:%d", nextPort)
|
|
||||||
} else {
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", strconv.Itoa(nextPort))
|
|
||||||
}
|
|
||||||
nextPort++
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
} else if modelConfig.Proxy == "" {
|
|
||||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
// check that members are all unique in the groups
|
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
// Check for duplicates within this group
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
// Check if member is used in another group
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
|
||||||
|
|
||||||
if config.Groups == nil {
|
|
||||||
config.Groups = make(map[string]GroupConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultGroup := GroupConfig{
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Members: []string{},
|
|
||||||
}
|
|
||||||
// if groups is empty, create a default group and put
|
|
||||||
// all models into it
|
|
||||||
if len(config.Groups) == 0 {
|
|
||||||
for modelName := range config.Models {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// iterate over existing group members and add non-grouped models into the default group
|
|
||||||
for modelName, _ := range config.Models {
|
|
||||||
foundModel := false
|
|
||||||
found:
|
|
||||||
// search for the model in existing groups
|
|
||||||
for _, groupConfig := range config.Groups {
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if member == modelName {
|
|
||||||
foundModel = true
|
|
||||||
break found
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundModel {
|
|
||||||
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
|
||||||
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Handle trailing backslashes by replacing with space
|
|
||||||
if strings.HasSuffix(trimmed, "\\") {
|
|
||||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
|
||||||
} else {
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// put it back together
|
|
||||||
cmdStr = strings.Join(cleanedLines, "\n")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
var args []string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
args = shlex.Windows.Split(cmdStr)
|
|
||||||
} else {
|
|
||||||
args = shlex.Posix.Split(cmdStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,593 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DEFAULT_GROUP_ID = "(default)"
|
||||||
|
|
||||||
|
type MacroEntry struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
|
||||||
|
type MacroList []MacroEntry
|
||||||
|
|
||||||
|
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
|
||||||
|
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
if value.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("macros must be a mapping")
|
||||||
|
}
|
||||||
|
|
||||||
|
// yaml.Node.Content for a mapping contains alternating key/value nodes
|
||||||
|
entries := make([]MacroEntry, 0, len(value.Content)/2)
|
||||||
|
for i := 0; i < len(value.Content); i += 2 {
|
||||||
|
keyNode := value.Content[i]
|
||||||
|
valueNode := value.Content[i+1]
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := keyNode.Decode(&name); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var val any
|
||||||
|
if err := valueNode.Decode(&val); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, MacroEntry{Name: name, Value: val})
|
||||||
|
}
|
||||||
|
|
||||||
|
*ml = entries
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a macro value by name
|
||||||
|
func (ml MacroList) Get(name string) (any, bool) {
|
||||||
|
for _, entry := range ml {
|
||||||
|
if entry.Name == name {
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap converts MacroList to a map (for backward compatibility if needed)
|
||||||
|
func (ml MacroList) ToMap() map[string]any {
|
||||||
|
result := make(map[string]any, len(ml))
|
||||||
|
for _, entry := range ml {
|
||||||
|
result[entry.Name] = entry.Value
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupConfig struct {
|
||||||
|
Swap bool `yaml:"swap"`
|
||||||
|
Exclusive bool `yaml:"exclusive"`
|
||||||
|
Persistent bool `yaml:"persistent"`
|
||||||
|
Members []string `yaml:"members"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// set default values for GroupConfig
|
||||||
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawGroupConfig GroupConfig
|
||||||
|
defaults := rawGroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Persistent: false,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = GroupConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HooksConfig struct {
|
||||||
|
OnStartup HookOnStartup `yaml:"on_startup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HookOnStartup struct {
|
||||||
|
Preload []string `yaml:"preload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
LogRequests bool `yaml:"logRequests"`
|
||||||
|
LogLevel string `yaml:"logLevel"`
|
||||||
|
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||||
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||||
|
|
||||||
|
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// map aliases to actual model IDs
|
||||||
|
aliases map[string]string
|
||||||
|
|
||||||
|
// automatic port assignments
|
||||||
|
StartPort int `yaml:"startPort"`
|
||||||
|
|
||||||
|
// hooks, see: #209
|
||||||
|
Hooks HooksConfig `yaml:"hooks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
|
if _, found := c.Models[search]; found {
|
||||||
|
return search, true
|
||||||
|
} else if name, found := c.aliases[search]; found {
|
||||||
|
return name, found
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||||
|
if realName, found := c.RealModelName(modelName); !found {
|
||||||
|
return ModelConfig{}, "", false
|
||||||
|
} else {
|
||||||
|
return c.Models[realName], realName, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(path string) (Config, error) {
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
return LoadConfigFromReader(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// default configuration values
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
}
|
||||||
|
err = yaml.Unmarshal(data, &config)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
// set a minimum of 15 seconds
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* check macro constraint rules:
|
||||||
|
|
||||||
|
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||||
|
- names must be less than 64 characters (no reason, just cause)
|
||||||
|
- name can not be any reserved macros: PORT, MODEL_ID
|
||||||
|
- macro values must be less than 1024 characters
|
||||||
|
*/
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
|
||||||
|
// Strip comments from command fields before macro expansion
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge global config and model macros. Model macros take precedence
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
|
||||||
|
// Add global macros first
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (can override global)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
// Remove any existing global macro with same name
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry // Override
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||||
|
// This allows later macros to reference earlier ones
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
// Substitute in command fields
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in metadata (recursive)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final pass: check if PORT macro is needed after macro expansion
|
||||||
|
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||||
|
// if it is required in either cmd or proxy keys
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort { // either has it
|
||||||
|
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PORT macro and substitute it
|
||||||
|
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute PORT in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
var err error
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there are no unknown macros that have not been replaced
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // this is ok, has to be replaced by process later
|
||||||
|
}
|
||||||
|
// Reserved macros are always valid (they should have been substituted already)
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
// Any other macro is unknown
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for unknown macros in metadata
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
// check that members are all unique in the groups
|
||||||
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
// Check for duplicates within this group
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
// Check if member is used in another group
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
|
if config.Groups == nil {
|
||||||
|
config.Groups = make(map[string]GroupConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGroup := GroupConfig{
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{},
|
||||||
|
}
|
||||||
|
// if groups is empty, create a default group and put
|
||||||
|
// all models into it
|
||||||
|
if len(config.Groups) == 0 {
|
||||||
|
for modelName := range config.Models {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over existing group members and add non-grouped models into the default group
|
||||||
|
for modelName := range config.Models {
|
||||||
|
foundModel := false
|
||||||
|
found:
|
||||||
|
// search for the model in existing groups
|
||||||
|
for _, groupConfig := range config.Groups {
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if member == modelName {
|
||||||
|
foundModel = true
|
||||||
|
break found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundModel {
|
||||||
|
defaultGroup.Members = append(defaultGroup.Members, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(defaultGroup.Members) // make consistent ordering for testing
|
||||||
|
config.Groups[DEFAULT_GROUP_ID] = defaultGroup
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if len(v) >= 1024 {
|
||||||
|
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||||
|
}
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||||
|
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// Test a command with spaces and newlines
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
-a "double quotes" \
|
||||||
|
--arg2 'single quotes'
|
||||||
|
-s
|
||||||
|
# comment 1
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"--arg2", "single quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", `"string in string"`,
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the default values are automatically set for global, model and group configurations
|
||||||
|
// after loading the configuration
|
||||||
|
func TestConfig_DefaultValuesPosix(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadPosix(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
hooks:
|
||||||
|
on_startup:
|
||||||
|
preload: ["model1", "model2"]
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
name: "Model 1"
|
||||||
|
description: "This is model 1"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: MacroList{
|
||||||
|
{"svr-path", "path/to/server"},
|
||||||
|
},
|
||||||
|
Hooks: HooksConfig{
|
||||||
|
OnStartup: HookOnStartup{
|
||||||
|
Preload: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
Name: "Model 1",
|
||||||
|
Description: "This is model 1",
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -0,0 +1,763 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
group2:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// a Contains as order of the map is not guaranteed
|
||||||
|
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
aliases:
|
||||||
|
- m1
|
||||||
|
- m2
|
||||||
|
`
|
||||||
|
// Load the config and verify
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
|
||||||
|
// this is a contains because it could be `model1` or `model2` depending on the order
|
||||||
|
// go decided on the order of the map
|
||||||
|
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_FindConfig(t *testing.T) {
|
||||||
|
|
||||||
|
// TODO?
|
||||||
|
// make make this shared between the different tests
|
||||||
|
config := &Config{
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "python model1.py",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "python model2.py",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2", "model-two"},
|
||||||
|
Env: []string{"VAR3=value3", "VAR4=value4"},
|
||||||
|
CheckEndpoint: "/status",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 10,
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test finding a model by its name
|
||||||
|
modelConfig, modelId, found := config.FindConfig("model1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", modelId)
|
||||||
|
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||||
|
|
||||||
|
// Test finding a model by its alias
|
||||||
|
modelConfig, modelId, found = config.FindConfig("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", modelId)
|
||||||
|
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||||
|
|
||||||
|
// Test finding a model that does not exist
|
||||||
|
modelConfig, modelId, found = config.FindConfig("model3")
|
||||||
|
assert.False(t, found)
|
||||||
|
assert.Equal(t, "", modelId)
|
||||||
|
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("Default Port Ranges", func(t *testing.T) {
|
||||||
|
content := ``
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
})
|
||||||
|
t.Run("User specific port ranges", func(t *testing.T) {
|
||||||
|
content := `startPort: 1000`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1000, config.StartPort)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Invalid start port", func(t *testing.T) {
|
||||||
|
content := `startPort: abcd`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start port must be greater than 1", func(t *testing.T) {
|
||||||
|
content := `startPort: -99`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Automatic port assignments", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 5800
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
model2:
|
||||||
|
cmd: svr --port ${PORT}
|
||||||
|
proxy: "http://172.11.22.33:${PORT}"
|
||||||
|
model3:
|
||||||
|
cmd: svr --port 1999
|
||||||
|
proxy: "http://1.2.3.4:1999"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
||||||
|
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
||||||
|
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
||||||
|
|
||||||
|
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
||||||
|
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: svr --port 111
|
||||||
|
`
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Equal(t, "model model1: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroReplacement(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
argOne: "--arg1"
|
||||||
|
argTwo: "--arg2"
|
||||||
|
autoPort: "--port ${PORT}"
|
||||||
|
overriddenByModelMacro: failed
|
||||||
|
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
macros:
|
||||||
|
overriddenByModelMacro: success
|
||||||
|
cmd: |
|
||||||
|
${svr-path} ${argTwo}
|
||||||
|
# the automatic ${PORT} is replaced
|
||||||
|
${autoPort}
|
||||||
|
${argOne}
|
||||||
|
--arg3 three
|
||||||
|
--overridden ${overriddenByModelMacro}
|
||||||
|
cmdStop: |
|
||||||
|
/path/to/stop.sh --port ${PORT} ${argTwo}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
|
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroReservedNames(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config string
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "global macro named PORT",
|
||||||
|
config: `
|
||||||
|
macros:
|
||||||
|
PORT: "1111"
|
||||||
|
`,
|
||||||
|
expectedError: "macro name 'PORT' is reserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "global macro named MODEL_ID",
|
||||||
|
config: `
|
||||||
|
macros:
|
||||||
|
MODEL_ID: model1
|
||||||
|
`,
|
||||||
|
expectedError: "macro name 'MODEL_ID' is reserved",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model macro named PORT",
|
||||||
|
config: `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
macros:
|
||||||
|
PORT: 1111
|
||||||
|
`,
|
||||||
|
expectedError: "model model1: macro name 'PORT' is reserved",
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "model macro named MODEL_ID",
|
||||||
|
config: `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
macros:
|
||||||
|
MODEL_ID: model1
|
||||||
|
`,
|
||||||
|
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, tt.expectedError, err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
field string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unknown macro in cmd",
|
||||||
|
field: "cmd",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: |
|
||||||
|
${svr-path} --port ${PORT}
|
||||||
|
${unknownMacro}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in cmdStop",
|
||||||
|
field: "cmdStop",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
cmdStop: "kill ${unknownMacro}"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in proxy",
|
||||||
|
field: "proxy",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
proxy: "http://${unknownMacro}:${PORT}"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in checkEndpoint",
|
||||||
|
field: "checkEndpoint",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
checkEndpoint: "http://localhost:${unknownMacro}/health"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown macro in filters.stripParams",
|
||||||
|
field: "filters.stripParams",
|
||||||
|
content: `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: "${svr-path} --port ${PORT}"
|
||||||
|
filters:
|
||||||
|
stripParams: "model,${unknownMacro}"
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
|
||||||
|
}
|
||||||
|
//t.Log(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestStripComments(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no comments",
|
||||||
|
input: "echo hello\necho world",
|
||||||
|
expected: "echo hello\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single comment line",
|
||||||
|
input: "# this is a comment\necho hello",
|
||||||
|
expected: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple comment lines",
|
||||||
|
input: "# comment 1\necho hello\n# comment 2\necho world",
|
||||||
|
expected: "echo hello\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comment with spaces",
|
||||||
|
input: " # indented comment\necho hello",
|
||||||
|
expected: "echo hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty lines preserved",
|
||||||
|
input: "echo hello\n\necho world",
|
||||||
|
expected: "echo hello\n\necho world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only comments",
|
||||||
|
input: "# comment 1\n# comment 2",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := StripComments(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
|
||||||
|
// Test case that reproduces the original bug where a macro in a comment
|
||||||
|
// would get expanded and cause the comment text to be included in the command
|
||||||
|
content := `
|
||||||
|
startPort: 9990
|
||||||
|
macros:
|
||||||
|
"latest-llama": >
|
||||||
|
/user/llama.cpp/build/bin/llama-server
|
||||||
|
--port ${PORT}
|
||||||
|
|
||||||
|
models:
|
||||||
|
"test-model":
|
||||||
|
cmd: |
|
||||||
|
# ${latest-llama} is a macro that is defined above
|
||||||
|
${latest-llama}
|
||||||
|
--model /path/to/model.gguf
|
||||||
|
-ngl 99
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Get the sanitized command
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Join the command for easier inspection
|
||||||
|
cmdStr := strings.Join(sanitizedCmd, " ")
|
||||||
|
|
||||||
|
// Verify that comment text is NOT present in the final command as separate arguments
|
||||||
|
commentWords := []string{"is", "macro", "that", "defined", "above"}
|
||||||
|
for _, word := range commentWords {
|
||||||
|
found := slices.Contains(sanitizedCmd, word)
|
||||||
|
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the actual command components ARE present
|
||||||
|
expectedParts := []string{
|
||||||
|
"/user/llama.cpp/build/bin/llama-server",
|
||||||
|
"--port",
|
||||||
|
"9990",
|
||||||
|
"--model",
|
||||||
|
"/path/to/model.gguf",
|
||||||
|
"-ngl",
|
||||||
|
"99",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range expectedParts {
|
||||||
|
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the server path appears exactly once (not duplicated due to macro expansion)
|
||||||
|
serverPath := "/user/llama.cpp/build/bin/llama-server"
|
||||||
|
count := strings.Count(cmdStr, serverPath)
|
||||||
|
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
|
||||||
|
|
||||||
|
// Verify the expected final command structure
|
||||||
|
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
|
||||||
|
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroModelId(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 9000
|
||||||
|
macros:
|
||||||
|
"docker-llama": docker run --name ${MODEL_ID} -p ${PORT}:8080 docker_img
|
||||||
|
"docker-stop": docker stop ${MODEL_ID}
|
||||||
|
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||||
|
|
||||||
|
model2:
|
||||||
|
cmd: ${docker-llama}
|
||||||
|
cmdStop: ${docker-stop}
|
||||||
|
|
||||||
|
author/model:F16:
|
||||||
|
cmd: /path/to/server -p ${PORT} -hf ${MODEL_ID}
|
||||||
|
cmdStop: stop
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
|
||||||
|
|
||||||
|
dockerStopMacro, found := config.Macros.Get("docker-stop")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
|
||||||
|
|
||||||
|
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "docker run --name model2 -p 9002:8080 docker_img", strings.Join(sanitizedCmd2, " "))
|
||||||
|
|
||||||
|
sanitizedCmdStop, err := SanitizeCommand(config.Models["model2"].CmdStop)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "docker stop model2", strings.Join(sanitizedCmdStop, " "))
|
||||||
|
|
||||||
|
sanitizedCmd3, err := SanitizeCommand(config.Models["author/model:F16"].Cmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
PORT_NUM: 10001
|
||||||
|
TEMP: 0.7
|
||||||
|
ENABLED: true
|
||||||
|
NAME: "llama model"
|
||||||
|
CTX: 16384
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
port: ${PORT_NUM}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
enabled: ${ENABLED}
|
||||||
|
model_name: ${NAME}
|
||||||
|
context: ${CTX}
|
||||||
|
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Verify direct substitution preserves types
|
||||||
|
assert.Equal(t, 10001, meta["port"])
|
||||||
|
assert.Equal(t, 0.7, meta["temperature"])
|
||||||
|
assert.Equal(t, true, meta["enabled"])
|
||||||
|
assert.Equal(t, "llama model", meta["model_name"])
|
||||||
|
assert.Equal(t, 16384, meta["context"])
|
||||||
|
|
||||||
|
// Verify string interpolation converts to string
|
||||||
|
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
TEMP: 0.7
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
config:
|
||||||
|
port: ${PORT}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
tags: ["model:${MODEL_ID}", "port:${PORT}"]
|
||||||
|
nested:
|
||||||
|
deep:
|
||||||
|
value: ${TEMP}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Verify nested objects
|
||||||
|
configMap := meta["config"].(map[string]any)
|
||||||
|
assert.Equal(t, 10000, configMap["port"])
|
||||||
|
assert.Equal(t, 0.7, configMap["temperature"])
|
||||||
|
|
||||||
|
// Verify arrays
|
||||||
|
tags := meta["tags"].([]any)
|
||||||
|
assert.Equal(t, "model:test-model", tags[0])
|
||||||
|
assert.Equal(t, "port:10000", tags[1])
|
||||||
|
|
||||||
|
// Verify deeply nested structures
|
||||||
|
nested := meta["nested"].(map[string]any)
|
||||||
|
deep := nested["deep"].(map[string]any)
|
||||||
|
assert.Equal(t, 0.7, deep["value"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
TEMP: 0.5
|
||||||
|
GLOBAL_VAL: "global"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
macros:
|
||||||
|
TEMP: 0.9
|
||||||
|
LOCAL_VAL: "local"
|
||||||
|
metadata:
|
||||||
|
temperature: ${TEMP}
|
||||||
|
global: ${GLOBAL_VAL}
|
||||||
|
local: ${LOCAL_VAL}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
meta := config.Models["test-model"].Metadata
|
||||||
|
assert.NotNil(t, meta)
|
||||||
|
|
||||||
|
// Model-level macro should override global
|
||||||
|
assert.Equal(t, 0.9, meta["temperature"])
|
||||||
|
// Global macro should be accessible
|
||||||
|
assert.Equal(t, "global", meta["global"])
|
||||||
|
// Model-level macro should be accessible
|
||||||
|
assert.Equal(t, "local", meta["local"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
metadata:
|
||||||
|
value: ${UNKNOWN_MACRO}
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "test-model")
|
||||||
|
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_InvalidMacroType(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
INVALID:
|
||||||
|
nested: value
|
||||||
|
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "INVALID")
|
||||||
|
assert.Contains(t, err.Error(), "must be a scalar type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MacroTypeValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
yaml string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
STR: "test"
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
NUM: 42
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "float macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
FLOAT: 3.14
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bool macro",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
BOOL: true
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array macro (invalid)",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
ARR: [1, 2, 3]
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map macro (invalid)",
|
||||||
|
yaml: `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
MAP:
|
||||||
|
key: value
|
||||||
|
models:
|
||||||
|
test-model:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
|
||||||
|
if tt.shouldErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
|
// does not support single quoted strings like in config_posix_test.go
|
||||||
|
args, err := SanitizeCommand(`python model1.py \
|
||||||
|
|
||||||
|
-a "double quotes" \
|
||||||
|
-s
|
||||||
|
--arg3 123 \
|
||||||
|
|
||||||
|
# comment 2
|
||||||
|
--arg4 '"string in string"'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# this will get stripped out as well as the white space above
|
||||||
|
-c "'single quoted'"
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{
|
||||||
|
"python", "model1.py",
|
||||||
|
"-a", "double quotes",
|
||||||
|
"-s",
|
||||||
|
"--arg3", "123",
|
||||||
|
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
||||||
|
"-c", `'single quoted'`,
|
||||||
|
}, args)
|
||||||
|
|
||||||
|
// Test an empty command
|
||||||
|
args, err = SanitizeCommand("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_DefaultValuesWindows(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 120, config.HealthCheckTimeout)
|
||||||
|
assert.Equal(t, 5800, config.StartPort)
|
||||||
|
assert.Equal(t, "info", config.LogLevel)
|
||||||
|
|
||||||
|
// Test default group exists
|
||||||
|
defaultGroup, exists := config.Groups["(default)"]
|
||||||
|
assert.True(t, exists, "default group should exist")
|
||||||
|
if assert.NotNil(t, defaultGroup, "default group should not be nil") {
|
||||||
|
assert.Equal(t, true, defaultGroup.Swap)
|
||||||
|
assert.Equal(t, true, defaultGroup.Exclusive)
|
||||||
|
assert.Equal(t, false, defaultGroup.Persistent)
|
||||||
|
assert.Equal(t, []string{"model1"}, defaultGroup.Members)
|
||||||
|
}
|
||||||
|
|
||||||
|
model1, exists := config.Models["model1"]
|
||||||
|
assert.True(t, exists, "model1 should exist")
|
||||||
|
if assert.NotNil(t, model1, "model1 should not be nil") {
|
||||||
|
assert.Equal(t, "path/to/cmd --port 5800", model1.Cmd) // has the port replaced
|
||||||
|
assert.Equal(t, "taskkill /f /t /pid ${PID}", model1.CmdStop)
|
||||||
|
assert.Equal(t, "http://localhost:5800", model1.Proxy)
|
||||||
|
assert.Equal(t, "/health", model1.CheckEndpoint)
|
||||||
|
assert.Equal(t, []string{}, model1.Aliases)
|
||||||
|
assert.Equal(t, []string{}, model1.Env)
|
||||||
|
assert.Equal(t, 0, model1.UnloadAfter)
|
||||||
|
assert.Equal(t, false, model1.Unlisted)
|
||||||
|
assert.Equal(t, "", model1.UseModelName)
|
||||||
|
assert.Equal(t, 0, model1.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// default empty filter exists
|
||||||
|
assert.Equal(t, "", model1.Filters.StripParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_LoadWindows(t *testing.T) {
|
||||||
|
// Create a temporary YAML file for testing
|
||||||
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
svr-path: "path/to/server"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
aliases:
|
||||||
|
- "m1"
|
||||||
|
- "model-one"
|
||||||
|
env:
|
||||||
|
- "VAR1=value1"
|
||||||
|
- "VAR2=value2"
|
||||||
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: ${svr-path} --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model3:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "mthree"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
model4:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8082"
|
||||||
|
checkEndpoint: "/"
|
||||||
|
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
swap: true
|
||||||
|
exclusive: false
|
||||||
|
members: ["model2"]
|
||||||
|
forever:
|
||||||
|
exclusive: false
|
||||||
|
persistent: true
|
||||||
|
members:
|
||||||
|
- "model4"
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write temporary file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the config and verify
|
||||||
|
config, err := LoadConfig(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := Config{
|
||||||
|
LogLevel: "info",
|
||||||
|
StartPort: 5800,
|
||||||
|
Macros: MacroList{
|
||||||
|
{"svr-path", "path/to/server"},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8080",
|
||||||
|
Aliases: []string{"m1", "model-one"},
|
||||||
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/server --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model3": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"mthree"},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
|
"model4": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
CmdStop: "taskkill /f /t /pid ${PID}",
|
||||||
|
Proxy: "http://localhost:8082",
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
"mthree": "model3",
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
DEFAULT_GROUP_ID: {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: true,
|
||||||
|
Members: []string{"model1", "model3"},
|
||||||
|
},
|
||||||
|
"group1": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Members: []string{"model2"},
|
||||||
|
},
|
||||||
|
"forever": {
|
||||||
|
Swap: true,
|
||||||
|
Exclusive: false,
|
||||||
|
Persistent: true,
|
||||||
|
Members: []string{"model4"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
|
}
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test macro-in-macro basic substitution
|
||||||
|
func TestConfig_MacroInMacroBasic(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-A"
|
||||||
|
"B": "prefix-${A}-suffix"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${B}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test LIFO substitution order with 3+ macro levels
|
||||||
|
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"base": "/models"
|
||||||
|
"path": "${base}/llama"
|
||||||
|
"full": "${path}/model.gguf"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: load ${full}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MODEL_ID in global macro used by model
|
||||||
|
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
|
||||||
|
|
||||||
|
models:
|
||||||
|
my-model:
|
||||||
|
cmd: ${podman-llama} -m model.gguf
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model macro overrides global macro in substitution
|
||||||
|
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"tag": "global"
|
||||||
|
"msg": "value-${tag}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
macros:
|
||||||
|
"tag": "model-level"
|
||||||
|
cmd: echo ${msg}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test self-reference detection error
|
||||||
|
func TestConfig_SelfReferenceDetection(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"recursive": "value-${recursive}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${recursive}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "recursive")
|
||||||
|
assert.Contains(t, err.Error(), "self-reference")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test undefined macro reference error
|
||||||
|
func TestConfig_UndefinedMacroReference(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
startPort: 10000
|
||||||
|
macros:
|
||||||
|
"A": "value-${UNDEFINED}"
|
||||||
|
|
||||||
|
models:
|
||||||
|
test:
|
||||||
|
cmd: echo ${A}
|
||||||
|
proxy: http://localhost:8080
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "UNDEFINED")
|
||||||
|
}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelConfig struct {
|
||||||
|
Cmd string `yaml:"cmd"`
|
||||||
|
CmdStop string `yaml:"cmdStop"`
|
||||||
|
Proxy string `yaml:"proxy"`
|
||||||
|
Aliases []string `yaml:"aliases"`
|
||||||
|
Env []string `yaml:"env"`
|
||||||
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
|
UnloadAfter int `yaml:"ttl"`
|
||||||
|
Unlisted bool `yaml:"unlisted"`
|
||||||
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// #179 for /v1/models
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
|
|
||||||
|
// Model filters see issue #174
|
||||||
|
Filters ModelFilters `yaml:"filters"`
|
||||||
|
|
||||||
|
// Macros: see #264
|
||||||
|
// Model level macros take precedence over the global macros
|
||||||
|
Macros MacroList `yaml:"macros"`
|
||||||
|
|
||||||
|
// Metadata: see #264
|
||||||
|
// Arbitrary metadata that can be exposed through the API
|
||||||
|
Metadata map[string]any `yaml:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelConfig ModelConfig
|
||||||
|
defaults := rawModelConfig{
|
||||||
|
Cmd: "",
|
||||||
|
CmdStop: "",
|
||||||
|
Proxy: "http://localhost:${PORT}",
|
||||||
|
Aliases: []string{},
|
||||||
|
Env: []string{},
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
UnloadAfter: 0,
|
||||||
|
Unlisted: false,
|
||||||
|
UseModelName: "",
|
||||||
|
ConcurrencyLimit: 0,
|
||||||
|
Name: "",
|
||||||
|
Description: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelConfig(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
return SanitizeCommand(m.Cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelFilters see issue #174
|
||||||
|
type ModelFilters struct {
|
||||||
|
StripParams string `yaml:"stripParams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
type rawModelFilters ModelFilters
|
||||||
|
defaults := rawModelFilters{
|
||||||
|
StripParams: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unmarshal(&defaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to unmarshal with the old field name for backwards compatibility
|
||||||
|
if defaults.StripParams == "" {
|
||||||
|
var legacy struct {
|
||||||
|
StripParams string `yaml:"strip_params"`
|
||||||
|
}
|
||||||
|
if legacyErr := unmarshal(&legacy); legacyErr != nil {
|
||||||
|
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
|
||||||
|
}
|
||||||
|
defaults.StripParams = legacy.StripParams
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = ModelFilters(defaults)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||||
|
if f.StripParams == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := strings.Split(f.StripParams, ",")
|
||||||
|
cleaned := make([]string, 0, len(params))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, param := range params {
|
||||||
|
trimmed := strings.TrimSpace(param)
|
||||||
|
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = true
|
||||||
|
cleaned = append(cleaned, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort cleaned
|
||||||
|
slices.Sort(cleaned)
|
||||||
|
return cleaned, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
|
config := &ModelConfig{
|
||||||
|
Cmd: `python model1.py \
|
||||||
|
--arg1 value1 \
|
||||||
|
--arg2 value2`,
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := config.SanitizedCommand()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_ModelFilters(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
macros:
|
||||||
|
default_strip: "temperature, top_p"
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
# macros inserted and list is cleaned of duplicates and empty strings
|
||||||
|
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
# check for strip_params (legacy field name) compatibility
|
||||||
|
legacy:
|
||||||
|
cmd: path/to/cmd --port ${PORT}
|
||||||
|
filters:
|
||||||
|
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
|
||||||
|
`
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for modelId, modelConfig := range config.Models {
|
||||||
|
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
|
||||||
|
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
|
||||||
|
sanitized, err := modelConfig.Filters.SanitizedStripParams()
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
// model has been removed
|
||||||
|
// empty strings have been removed
|
||||||
|
// duplicates have been removed
|
||||||
|
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
|
||||||
// Test a command with spaces and newlines
|
|
||||||
args, err := SanitizeCommand(`python model1.py \
|
|
||||||
-a "double quotes" \
|
|
||||||
--arg2 'single quotes'
|
|
||||||
-s
|
|
||||||
# comment 1
|
|
||||||
--arg3 123 \
|
|
||||||
|
|
||||||
# comment 2
|
|
||||||
--arg4 '"string in string"'
|
|
||||||
|
|
||||||
|
|
||||||
# this will get stripped out as well as the white space above
|
|
||||||
-c "'single quoted'"
|
|
||||||
`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{
|
|
||||||
"python", "model1.py",
|
|
||||||
"-a", "double quotes",
|
|
||||||
"--arg2", "single quotes",
|
|
||||||
"-s",
|
|
||||||
"--arg3", "123",
|
|
||||||
"--arg4", `"string in string"`,
|
|
||||||
"-c", `'single quoted'`,
|
|
||||||
}, args)
|
|
||||||
|
|
||||||
// Test an empty command
|
|
||||||
args, err = SanitizeCommand("")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, args)
|
|
||||||
}
|
|
||||||
@@ -1,333 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
|
||||||
// Create a temporary YAML file for testing
|
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
aliases:
|
|
||||||
- "m1"
|
|
||||||
- "model-one"
|
|
||||||
env:
|
|
||||||
- "VAR1=value1"
|
|
||||||
- "VAR2=value2"
|
|
||||||
checkEndpoint: "/health"
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
aliases:
|
|
||||||
- "m2"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model3:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
aliases:
|
|
||||||
- "mthree"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model4:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8082"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
|
||||||
profiles:
|
|
||||||
test:
|
|
||||||
- model1
|
|
||||||
- model2
|
|
||||||
groups:
|
|
||||||
group1:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
forever:
|
|
||||||
exclusive: false
|
|
||||||
persistent: true
|
|
||||||
members:
|
|
||||||
- "model4"
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write temporary file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the config and verify
|
|
||||||
config, err := LoadConfig(tempFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := Config{
|
|
||||||
StartPort: 5800,
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8080",
|
|
||||||
Aliases: []string{"m1", "model-one"},
|
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
},
|
|
||||||
"model2": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"m2"},
|
|
||||||
Env: nil,
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
"model3": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"mthree"},
|
|
||||||
Env: nil,
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
"model4": {
|
|
||||||
Cmd: "path/to/cmd --arg1 one",
|
|
||||||
Proxy: "http://localhost:8082",
|
|
||||||
CheckEndpoint: "/",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Profiles: map[string][]string{
|
|
||||||
"test": {"model1", "model2"},
|
|
||||||
},
|
|
||||||
aliases: map[string]string{
|
|
||||||
"m1": "model1",
|
|
||||||
"model-one": "model1",
|
|
||||||
"m2": "model2",
|
|
||||||
"mthree": "model3",
|
|
||||||
},
|
|
||||||
Groups: map[string]GroupConfig{
|
|
||||||
DEFAULT_GROUP_ID: {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: true,
|
|
||||||
Members: []string{"model1", "model3"},
|
|
||||||
},
|
|
||||||
"group1": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Members: []string{"model2"},
|
|
||||||
},
|
|
||||||
"forever": {
|
|
||||||
Swap: true,
|
|
||||||
Exclusive: false,
|
|
||||||
Persistent: true,
|
|
||||||
Members: []string{"model4"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, expected, config)
|
|
||||||
|
|
||||||
realname, found := config.RealModelName("m1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", realname)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_GroupMemberIsUnique(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
model3:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
|
|
||||||
healthCheckTimeout: 15
|
|
||||||
groups:
|
|
||||||
group1:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
group2:
|
|
||||||
swap: true
|
|
||||||
exclusive: false
|
|
||||||
members: ["model2"]
|
|
||||||
`
|
|
||||||
// Load the config and verify
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
|
|
||||||
// a Contains as order of the map is not guaranteed
|
|
||||||
assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_ModelAliasesAreUnique(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8080"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
model2:
|
|
||||||
cmd: path/to/cmd --arg1 one
|
|
||||||
proxy: "http://localhost:8081"
|
|
||||||
checkEndpoint: "/"
|
|
||||||
aliases:
|
|
||||||
- m1
|
|
||||||
- m2
|
|
||||||
`
|
|
||||||
// Load the config and verify
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
|
|
||||||
// this is a contains because it could be `model1` or `model2` depending on the order
|
|
||||||
// go decided on the order of the map
|
|
||||||
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
|
||||||
config := &ModelConfig{
|
|
||||||
Cmd: `python model1.py \
|
|
||||||
--arg1 value1 \
|
|
||||||
--arg2 value2`,
|
|
||||||
}
|
|
||||||
|
|
||||||
args, err := config.SanitizedCommand()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_FindConfig(t *testing.T) {
|
|
||||||
|
|
||||||
// TODO?
|
|
||||||
// make make this shared between the different tests
|
|
||||||
config := &Config{
|
|
||||||
Models: map[string]ModelConfig{
|
|
||||||
"model1": {
|
|
||||||
Cmd: "python model1.py",
|
|
||||||
Proxy: "http://localhost:8080",
|
|
||||||
Aliases: []string{"m1", "model-one"},
|
|
||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
},
|
|
||||||
"model2": {
|
|
||||||
Cmd: "python model2.py",
|
|
||||||
Proxy: "http://localhost:8081",
|
|
||||||
Aliases: []string{"m2", "model-two"},
|
|
||||||
Env: []string{"VAR3=value3", "VAR4=value4"},
|
|
||||||
CheckEndpoint: "/status",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
HealthCheckTimeout: 10,
|
|
||||||
aliases: map[string]string{
|
|
||||||
"m1": "model1",
|
|
||||||
"model-one": "model1",
|
|
||||||
"m2": "model2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test finding a model by its name
|
|
||||||
modelConfig, modelId, found := config.FindConfig("model1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", modelId)
|
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
|
||||||
|
|
||||||
// Test finding a model by its alias
|
|
||||||
modelConfig, modelId, found = config.FindConfig("m1")
|
|
||||||
assert.True(t, found)
|
|
||||||
assert.Equal(t, "model1", modelId)
|
|
||||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
|
||||||
|
|
||||||
// Test finding a model that does not exist
|
|
||||||
modelConfig, modelId, found = config.FindConfig("model3")
|
|
||||||
assert.False(t, found)
|
|
||||||
assert.Equal(t, "", modelId)
|
|
||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_AutomaticPortAssignments(t *testing.T) {
|
|
||||||
|
|
||||||
t.Run("Default Port Ranges", func(t *testing.T) {
|
|
||||||
content := ``
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
})
|
|
||||||
t.Run("User specific port ranges", func(t *testing.T) {
|
|
||||||
content := `startPort: 1000`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 1000, config.StartPort)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Invalid start port", func(t *testing.T) {
|
|
||||||
content := `startPort: abcd`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("start port must be greater than 1", func(t *testing.T) {
|
|
||||||
content := `startPort: -99`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Automatic port assignments", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
startPort: 5800
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
model2:
|
|
||||||
cmd: svr --port ${PORT}
|
|
||||||
proxy: "http://172.11.22.33:${PORT}"
|
|
||||||
model3:
|
|
||||||
cmd: svr --port 1999
|
|
||||||
proxy: "http://1.2.3.4:1999"
|
|
||||||
`
|
|
||||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 5800, config.StartPort)
|
|
||||||
assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd)
|
|
||||||
assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd)
|
|
||||||
assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy)
|
|
||||||
|
|
||||||
assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd)
|
|
||||||
assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) {
|
|
||||||
content := `
|
|
||||||
models:
|
|
||||||
model1:
|
|
||||||
cmd: svr --port 111
|
|
||||||
`
|
|
||||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
|
||||||
assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConfig_SanitizeCommand(t *testing.T) {
|
|
||||||
// does not support single quoted strings like in config_posix_test.go
|
|
||||||
args, err := SanitizeCommand(`python model1.py \
|
|
||||||
|
|
||||||
-a "double quotes" \
|
|
||||||
-s
|
|
||||||
--arg3 123 \
|
|
||||||
|
|
||||||
# comment 2
|
|
||||||
--arg4 '"string in string"'
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# this will get stripped out as well as the white space above
|
|
||||||
-c "'single quoted'"
|
|
||||||
`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{
|
|
||||||
"python", "model1.py",
|
|
||||||
"-a", "double quotes",
|
|
||||||
"-s",
|
|
||||||
"--arg3", "123",
|
|
||||||
"--arg4", "'string in string'", // this is a little weird but the lexer says so...?
|
|
||||||
"-c", `'single quoted'`,
|
|
||||||
}, args)
|
|
||||||
|
|
||||||
// Test an empty command
|
|
||||||
args, err = SanitizeCommand("")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, args)
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Custom discard writer that implements http.ResponseWriter but just discards everything
|
||||||
|
type DiscardWriter struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Header() http.Header {
|
||||||
|
if w.header == nil {
|
||||||
|
w.header = make(http.Header)
|
||||||
|
}
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *DiscardWriter) WriteHeader(code int) {
|
||||||
|
w.status = code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Satisfy the http.Flusher interface for streaming responses
|
||||||
|
func (w *DiscardWriter) Flush() {}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
// package level registry of the different event types
|
||||||
|
|
||||||
|
const ProcessStateChangeEventID = 0x01
|
||||||
|
const ChatCompletionStatsEventID = 0x02
|
||||||
|
const ConfigFileChangedEventID = 0x03
|
||||||
|
const LogDataEventID = 0x04
|
||||||
|
const TokenMetricsEventID = 0x05
|
||||||
|
const ModelPreloadedEventID = 0x06
|
||||||
|
|
||||||
|
type ProcessStateChangeEvent struct {
|
||||||
|
ProcessName string
|
||||||
|
NewState ProcessState
|
||||||
|
OldState ProcessState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ProcessStateChangeEvent) Type() uint32 {
|
||||||
|
return ProcessStateChangeEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStats struct {
|
||||||
|
TokensGenerated int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ChatCompletionStats) Type() uint32 {
|
||||||
|
return ChatCompletionStatsEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReloadingState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReloadingStateStart ReloadingState = iota
|
||||||
|
ReloadingStateEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigFileChangedEvent struct {
|
||||||
|
ReloadingState ReloadingState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ConfigFileChangedEvent) Type() uint32 {
|
||||||
|
return ConfigFileChangedEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogDataEvent struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e LogDataEvent) Type() uint32 {
|
||||||
|
return LogDataEventID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelPreloadedEvent struct {
|
||||||
|
ModelName string
|
||||||
|
Success bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ModelPreloadedEvent) Type() uint32 {
|
||||||
|
return ModelPreloadedEventID
|
||||||
|
}
|
||||||
@@ -9,12 +9,15 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nextTestPort int = 12000
|
nextTestPort int = 12000
|
||||||
portMutex sync.Mutex
|
portMutex sync.Mutex
|
||||||
testLogger = NewLogMonitorWriter(os.Stdout)
|
testLogger = NewLogMonitorWriter(os.Stdout)
|
||||||
|
simpleResponderPath = getSimpleResponderPath()
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
// Check if the binary exists
|
||||||
@@ -63,17 +66,21 @@ func getTestPort() int {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
|
||||||
binaryPath := getSimpleResponderPath()
|
// Create a YAML string with just the values we want to set
|
||||||
|
yamlStr := fmt.Sprintf(`
|
||||||
|
cmd: '%s --port %d --silent --respond %s'
|
||||||
|
proxy: "http://127.0.0.1:%d"
|
||||||
|
`, simpleResponderPath, port, expectedMessage, port)
|
||||||
|
|
||||||
// Create a process configuration
|
var cfg config.ModelConfig
|
||||||
return ModelConfig{
|
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
|
||||||
Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
|
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
}
|
}
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 15 KiB |
@@ -1,14 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>llama-swap</title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<h1>llama-swap</h1>
|
|
||||||
<p>
|
|
||||||
<a href="/logs">view logs</a> | <a href="/upstream">configured models</a> | <a href="https://github.com/mostlygeek/llama-swap">github</a>
|
|
||||||
</p>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Logs</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
margin: 0;
|
|
||||||
height: 100vh;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
font-family: "Courier New", Courier, monospace;
|
|
||||||
}
|
|
||||||
.log-container {
|
|
||||||
display: flex;
|
|
||||||
flex: 1;
|
|
||||||
gap: 0.5em;
|
|
||||||
margin: 0.5em;
|
|
||||||
min-height: 0;
|
|
||||||
}
|
|
||||||
.log-column {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
flex: 1;
|
|
||||||
min-width: 0;
|
|
||||||
transition: flex 0.3s ease;
|
|
||||||
}
|
|
||||||
.log-column.minimized {
|
|
||||||
flex: 0.1;
|
|
||||||
max-width: 50px;
|
|
||||||
border: 1px solid #777;
|
|
||||||
color: green;
|
|
||||||
}
|
|
||||||
.log-controls {
|
|
||||||
display: grid;
|
|
||||||
grid-template-columns: 1fr auto;
|
|
||||||
gap: 0.5em;
|
|
||||||
margin-bottom: 0.5em;
|
|
||||||
}
|
|
||||||
.log-controls input {
|
|
||||||
width: 100%;
|
|
||||||
padding: 4px;
|
|
||||||
}
|
|
||||||
.log-controls input:focus {
|
|
||||||
outline: none;
|
|
||||||
}
|
|
||||||
.log-stream {
|
|
||||||
flex: 1;
|
|
||||||
padding: 1em;
|
|
||||||
background: #f4f4f4;
|
|
||||||
overflow-y: auto;
|
|
||||||
white-space: pre-wrap;
|
|
||||||
word-wrap: break-word;
|
|
||||||
min-height: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.regex-error {
|
|
||||||
background-color: #ff0000 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Make headers clickable and show pointer cursor */
|
|
||||||
h2 {
|
|
||||||
cursor: pointer;
|
|
||||||
user-select: none;
|
|
||||||
margin: 0 0 0.5em 0;
|
|
||||||
padding: 0.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
h2:hover {
|
|
||||||
background-color: rgba(0, 0, 0, 0.05);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Dark mode styles */
|
|
||||||
@media (prefers-color-scheme: dark) {
|
|
||||||
body {
|
|
||||||
background-color: #333;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-stream {
|
|
||||||
background: #444;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-controls input {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-controls button {
|
|
||||||
background: #555;
|
|
||||||
color: #fff;
|
|
||||||
border: 1px solid #777;
|
|
||||||
}
|
|
||||||
|
|
||||||
h2:hover {
|
|
||||||
background-color: rgba(255, 255, 255, 0.1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Hide content when minimized */
|
|
||||||
.log-column.minimized .log-controls,
|
|
||||||
.log-column.minimized .log-stream {
|
|
||||||
display: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-column.minimized h2 {
|
|
||||||
writing-mode: vertical-rl;
|
|
||||||
text-orientation: mixed;
|
|
||||||
transform: rotate(180deg);
|
|
||||||
white-space: nowrap;
|
|
||||||
margin: auto;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="log-container">
|
|
||||||
<div class="log-column">
|
|
||||||
<h2>Proxy Logs</h2>
|
|
||||||
<div class="log-controls">
|
|
||||||
<input type="text" id="proxy-filter-input" placeholder="proxy regex filter">
|
|
||||||
<button id="proxy-clear-button">clear</button>
|
|
||||||
</div>
|
|
||||||
<pre class="log-stream" id="proxy-log-stream">Waiting for proxy logs...</pre>
|
|
||||||
</div>
|
|
||||||
<div class="log-column minimized">
|
|
||||||
<h2>Upstream Logs</h2>
|
|
||||||
<div class="log-controls">
|
|
||||||
<input type="text" id="upstream-filter-input" placeholder="upstream regex filter">
|
|
||||||
<button id="upstream-clear-button">clear</button>
|
|
||||||
</div>
|
|
||||||
<pre class="log-stream" id="upstream-log-stream">Waiting for upstream logs...</pre>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<script>
|
|
||||||
class LogStream {
|
|
||||||
constructor(streamElement, filterInput, clearButton, endpoint) {
|
|
||||||
this.streamElement = streamElement;
|
|
||||||
this.filterInput = filterInput;
|
|
||||||
this.clearButton = clearButton;
|
|
||||||
this.endpoint = endpoint;
|
|
||||||
this.logData = "";
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.eventSource = null;
|
|
||||||
|
|
||||||
this.initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
initialize() {
|
|
||||||
this.filterInput.addEventListener('input', () => this.updateFilter());
|
|
||||||
this.clearButton.addEventListener('click', () => {
|
|
||||||
this.filterInput.value = "";
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.render();
|
|
||||||
});
|
|
||||||
this.setupEventSource();
|
|
||||||
}
|
|
||||||
|
|
||||||
setupEventSource() {
|
|
||||||
if (typeof(EventSource) === "undefined") {
|
|
||||||
this.logData = "SSE Not supported by this browser.";
|
|
||||||
this.render();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const connect = () => {
|
|
||||||
this.eventSource = new EventSource(this.endpoint);
|
|
||||||
|
|
||||||
this.eventSource.onmessage = (event) => {
|
|
||||||
this.logData += event.data;
|
|
||||||
this.logData = this.logData.slice(-1024 * 100);
|
|
||||||
this.render();
|
|
||||||
};
|
|
||||||
|
|
||||||
this.eventSource.onerror = (err) => {
|
|
||||||
// Close the current connection
|
|
||||||
this.eventSource.close();
|
|
||||||
|
|
||||||
this.logData += "\nConnection lost. Retrying in 5 seconds...\n";
|
|
||||||
this.render();
|
|
||||||
|
|
||||||
// Attempt to reconnect after 5 seconds
|
|
||||||
setTimeout(() => {
|
|
||||||
this.logData += "Attempting to reconnect...\n";
|
|
||||||
this.render();
|
|
||||||
connect();
|
|
||||||
}, 5000);
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initial connection
|
|
||||||
connect();
|
|
||||||
}
|
|
||||||
|
|
||||||
render() {
|
|
||||||
let content = this.logData;
|
|
||||||
|
|
||||||
if (this.regexFilter) {
|
|
||||||
const lines = content.split('\n');
|
|
||||||
const filteredLines = lines.filter(line => this.regexFilter.test(line));
|
|
||||||
content = filteredLines.length > 0 ? filteredLines.join('\n') + '\n' : "";
|
|
||||||
}
|
|
||||||
|
|
||||||
this.streamElement.textContent = content;
|
|
||||||
this.streamElement.scrollTop = this.streamElement.scrollHeight;
|
|
||||||
}
|
|
||||||
|
|
||||||
updateFilter() {
|
|
||||||
const pattern = this.filterInput.value.trim();
|
|
||||||
this.filterInput.classList.remove('regex-error');
|
|
||||||
|
|
||||||
if (!pattern) {
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.render();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.regexFilter = new RegExp(pattern);
|
|
||||||
} catch (e) {
|
|
||||||
console.error("Invalid regex pattern:", e);
|
|
||||||
this.regexFilter = null;
|
|
||||||
this.filterInput.classList.add('regex-error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize both log streams
|
|
||||||
document.addEventListener('DOMContentLoaded', () => {
|
|
||||||
new LogStream(
|
|
||||||
document.getElementById('proxy-log-stream'),
|
|
||||||
document.getElementById('proxy-filter-input'),
|
|
||||||
document.getElementById('proxy-clear-button'),
|
|
||||||
"/logs/streamSSE/proxy"
|
|
||||||
);
|
|
||||||
|
|
||||||
new LogStream(
|
|
||||||
document.getElementById('upstream-log-stream'),
|
|
||||||
document.getElementById('upstream-filter-input'),
|
|
||||||
document.getElementById('upstream-clear-button'),
|
|
||||||
"/logs/streamSSE/upstream"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Initialize clickable headers
|
|
||||||
document.querySelectorAll('h2').forEach(header => {
|
|
||||||
header.addEventListener('click', () => {
|
|
||||||
const column = header.closest('.log-column');
|
|
||||||
column.classList.toggle('minimized');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import "embed"
|
|
||||||
|
|
||||||
//go:embed html
|
|
||||||
var htmlFiles embed.FS
|
|
||||||
|
|
||||||
func getHTMLFile(path string) ([]byte, error) {
|
|
||||||
return htmlFiles.ReadFile("html/" + path)
|
|
||||||
}
|
|
||||||
@@ -2,10 +2,13 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/ring"
|
"container/ring"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
@@ -18,7 +21,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type LogMonitor struct {
|
type LogMonitor struct {
|
||||||
clients map[chan []byte]bool
|
eventbus *event.Dispatcher
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
buffer *ring.Ring
|
buffer *ring.Ring
|
||||||
bufferMu sync.RWMutex
|
bufferMu sync.RWMutex
|
||||||
@@ -37,11 +40,11 @@ func NewLogMonitor() *LogMonitor {
|
|||||||
|
|
||||||
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
|
||||||
return &LogMonitor{
|
return &LogMonitor{
|
||||||
clients: make(map[chan []byte]bool),
|
eventbus: event.NewDispatcherConfig(1000),
|
||||||
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
|
||||||
stdout: stdout,
|
stdout: stdout,
|
||||||
level: LevelInfo,
|
level: LevelInfo,
|
||||||
prefix: "",
|
prefix: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,34 +84,14 @@ func (w *LogMonitor) GetHistory() []byte {
|
|||||||
return history
|
return history
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) Subscribe() chan []byte {
|
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
|
||||||
w.mu.Lock()
|
return event.Subscribe(w.eventbus, func(e LogDataEvent) {
|
||||||
defer w.mu.Unlock()
|
callback(e.Data)
|
||||||
|
})
|
||||||
ch := make(chan []byte, 100)
|
|
||||||
w.clients[ch] = true
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
delete(w.clients, ch)
|
|
||||||
close(ch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) broadcast(msg []byte) {
|
func (w *LogMonitor) broadcast(msg []byte) {
|
||||||
w.mu.RLock()
|
event.Publish(w.eventbus, LogDataEvent{Data: msg})
|
||||||
defer w.mu.RUnlock()
|
|
||||||
|
|
||||||
for client := range w.clients {
|
|
||||||
select {
|
|
||||||
case client <- msg:
|
|
||||||
default:
|
|
||||||
// If client buffer is full, skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *LogMonitor) SetPrefix(prefix string) {
|
func (w *LogMonitor) SetPrefix(prefix string) {
|
||||||
|
|||||||
@@ -10,38 +10,29 @@ import (
|
|||||||
func TestLogMonitor(t *testing.T) {
|
func TestLogMonitor(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
|
||||||
// Test subscription
|
// A WaitGroup is used to wait for all the expected writes to complete
|
||||||
client1 := logMonitor.Subscribe()
|
var wg sync.WaitGroup
|
||||||
client2 := logMonitor.Subscribe()
|
|
||||||
|
|
||||||
defer logMonitor.Unsubscribe(client1)
|
|
||||||
defer logMonitor.Unsubscribe(client2)
|
|
||||||
|
|
||||||
client1Messages := make([]byte, 0)
|
client1Messages := make([]byte, 0)
|
||||||
client2Messages := make([]byte, 0)
|
client2Messages := make([]byte, 0)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
wg.Add(1)
|
client1Messages = append(client1Messages, data...)
|
||||||
|
wg.Done()
|
||||||
|
})()
|
||||||
|
|
||||||
go func() {
|
defer logMonitor.OnLogData(func(data []byte) {
|
||||||
defer wg.Done()
|
client2Messages = append(client2Messages, data...)
|
||||||
for {
|
wg.Done()
|
||||||
select {
|
})()
|
||||||
case data := <-client1:
|
|
||||||
client1Messages = append(client1Messages, data...)
|
wg.Add(6) // 2 x 3 writes
|
||||||
case data := <-client2:
|
|
||||||
client2Messages = append(client2Messages, data...)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logMonitor.Write([]byte("1"))
|
logMonitor.Write([]byte("1"))
|
||||||
logMonitor.Write([]byte("2"))
|
logMonitor.Write([]byte("2"))
|
||||||
logMonitor.Write([]byte("3"))
|
logMonitor.Write([]byte("3"))
|
||||||
|
|
||||||
// Wait for the goroutine to finish
|
// wait for all writes to complete
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// Check the buffer
|
// Check the buffer
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MetricsRecorder struct {
|
||||||
|
metricsMonitor *MetricsMonitor
|
||||||
|
realModelName string
|
||||||
|
// isStreaming bool
|
||||||
|
startTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||||
|
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := &MetricsResponseWriter{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
metricsRecorder: &MetricsRecorder{
|
||||||
|
metricsMonitor: pm.metricsMonitor,
|
||||||
|
realModelName: realModelName,
|
||||||
|
startTime: time.Now(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.Writer = writer
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// check for streaming response
|
||||||
|
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
writer.metricsRecorder.processStreamingResponse(writer.body)
|
||||||
|
} else {
|
||||||
|
writer.metricsRecorder.processNonStreamingResponse(writer.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
||||||
|
usage := jsonData.Get("usage")
|
||||||
|
timings := jsonData.Get("timings")
|
||||||
|
if !usage.Exists() && !timings.Exists() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// default values
|
||||||
|
cachedTokens := -1 // unknown or missing data
|
||||||
|
outputTokens := 0
|
||||||
|
inputTokens := 0
|
||||||
|
|
||||||
|
// timings data
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
||||||
|
|
||||||
|
if usage.Exists() {
|
||||||
|
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||||
|
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||||
|
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||||
|
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||||
|
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||||
|
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||||
|
|
||||||
|
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
||||||
|
cachedTokens = int(cachedValue.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rec.metricsMonitor.addMetrics(TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: rec.realModelName,
|
||||||
|
CachedTokens: cachedTokens,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
DurationMs: durationMs,
|
||||||
|
})
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||||
|
// Iterate **backwards** through the lines looking for the data payload with
|
||||||
|
// usage data
|
||||||
|
lines := bytes.Split(body, []byte("\n"))
|
||||||
|
|
||||||
|
for i := len(lines) - 1; i >= 0; i-- {
|
||||||
|
line := bytes.TrimSpace(lines[i])
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payload always follows "data:"
|
||||||
|
prefix := []byte("data:")
|
||||||
|
if !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
// [DONE] line itself contains nothing of interest.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.ValidBytes(data) {
|
||||||
|
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
||||||
|
return // short circuit if a metric was recorded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON to extract usage information
|
||||||
|
if gjson.ValidBytes(body) {
|
||||||
|
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsResponseWriter captures the entire response for non-streaming
|
||||||
|
type MetricsResponseWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body []byte
|
||||||
|
metricsRecorder *MetricsRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
n, err := w.ResponseWriter.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
w.body = append(w.body, b...)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MetricsResponseWriter) Header() http.Header {
|
||||||
|
return w.ResponseWriter.Header()
|
||||||
|
}
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
|
type TokenMetrics struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CachedTokens int `json:"cache_tokens"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
|
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenMetricsEvent represents a token metrics event
|
||||||
|
type TokenMetricsEvent struct {
|
||||||
|
Metrics TokenMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e TokenMetricsEvent) Type() uint32 {
|
||||||
|
return TokenMetricsEventID // defined in events.go
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsMonitor parses llama-server output for token statistics
|
||||||
|
type MetricsMonitor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
metrics []TokenMetrics
|
||||||
|
maxMetrics int
|
||||||
|
nextID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
||||||
|
maxMetrics := config.MetricsMaxInMemory
|
||||||
|
if maxMetrics <= 0 {
|
||||||
|
maxMetrics = 1000 // Default fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := &MetricsMonitor{
|
||||||
|
maxMetrics: maxMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMetrics adds a new metric to the collection and publishes an event
|
||||||
|
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
metric.ID = mp.nextID
|
||||||
|
mp.nextID++
|
||||||
|
mp.metrics = append(mp.metrics, metric)
|
||||||
|
if len(mp.metrics) > mp.maxMetrics {
|
||||||
|
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||||
|
}
|
||||||
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns a copy of the current metrics
|
||||||
|
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]TokenMetrics, len(mp.metrics))
|
||||||
|
copy(result, mp.metrics)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetricsJSON returns metrics as JSON
|
||||||
|
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
return json.Marshal(mp.metrics)
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -13,6 +14,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessState string
|
type ProcessState string
|
||||||
@@ -23,9 +27,6 @@ const (
|
|||||||
StateReady ProcessState = ProcessState("ready")
|
StateReady ProcessState = ProcessState("ready")
|
||||||
StateStopping ProcessState = ProcessState("stopping")
|
StateStopping ProcessState = ProcessState("stopping")
|
||||||
|
|
||||||
// failed a health check on start and will not be recovered
|
|
||||||
StateFailed ProcessState = ProcessState("failed")
|
|
||||||
|
|
||||||
// process is shutdown and will not be restarted
|
// process is shutdown and will not be restarted
|
||||||
StateShutdown ProcessState = ProcessState("shutdown")
|
StateShutdown ProcessState = ProcessState("shutdown")
|
||||||
)
|
)
|
||||||
@@ -39,11 +40,14 @@ const (
|
|||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config config.ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
|
||||||
// for p.cmd.Wait() select { ... }
|
// PR #155 called to cancel the upstream process
|
||||||
cmdWaitChan chan error
|
cancelUpstream context.CancelFunc
|
||||||
|
|
||||||
|
// closed when command exits
|
||||||
|
cmdWaitChan chan struct{}
|
||||||
|
|
||||||
processLogger *LogMonitor
|
processLogger *LogMonitor
|
||||||
proxyLogger *LogMonitor
|
proxyLogger *LogMonitor
|
||||||
@@ -61,47 +65,40 @@ type Process struct {
|
|||||||
// used to block on multiple start() calls
|
// used to block on multiple start() calls
|
||||||
waitStarting sync.WaitGroup
|
waitStarting sync.WaitGroup
|
||||||
|
|
||||||
// for managing shutdown state
|
|
||||||
shutdownCtx context.Context
|
|
||||||
shutdownCancel context.CancelFunc
|
|
||||||
|
|
||||||
// for managing concurrency limits
|
// for managing concurrency limits
|
||||||
concurrencyLimitSemaphore chan struct{}
|
concurrencyLimitSemaphore chan struct{}
|
||||||
|
|
||||||
// stop timeout waiting for graceful shutdown
|
// used for testing to override the default value
|
||||||
gracefulStopTimeout time.Duration
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
// track that this happened
|
// track the number of failed starts
|
||||||
upstreamWasStoppedWithKill bool
|
failedStartCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
concurrentLimit := 10
|
concurrentLimit := 10
|
||||||
if config.ConcurrencyLimit > 0 {
|
if config.ConcurrencyLimit > 0 {
|
||||||
concurrentLimit = config.ConcurrencyLimit
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
} else {
|
|
||||||
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
cmdWaitChan: make(chan error, 1),
|
cancelUpstream: nil,
|
||||||
processLogger: processLogger,
|
processLogger: processLogger,
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
healthCheckTimeout: healthCheckTimeout,
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */
|
||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
|
||||||
shutdownCancel: cancel,
|
|
||||||
|
|
||||||
// concurrency limit
|
// concurrency limit
|
||||||
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
|
|
||||||
|
// To be removed when migration over exec.CommandContext is complete
|
||||||
// stop timeout
|
// stop timeout
|
||||||
gracefulStopTimeout: 5 * time.Second,
|
gracefulStopTimeout: 10 * time.Second,
|
||||||
upstreamWasStoppedWithKill: false,
|
cmdWaitChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,6 +131,7 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
|
|||||||
|
|
||||||
p.state = newState
|
p.state = newState
|
||||||
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
|
||||||
|
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
|
||||||
return p.state, nil
|
return p.state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,12 +141,12 @@ func isValidTransition(from, to ProcessState) bool {
|
|||||||
case StateStopped:
|
case StateStopped:
|
||||||
return to == StateStarting
|
return to == StateStarting
|
||||||
case StateStarting:
|
case StateStarting:
|
||||||
return to == StateReady || to == StateFailed || to == StateStopping
|
return to == StateReady || to == StateStopping || to == StateStopped
|
||||||
case StateReady:
|
case StateReady:
|
||||||
return to == StateStopping
|
return to == StateStopping
|
||||||
case StateStopping:
|
case StateStopping:
|
||||||
return to == StateStopped || to == StateShutdown
|
return to == StateStopped || to == StateShutdown
|
||||||
case StateFailed, StateShutdown:
|
case StateShutdown:
|
||||||
return false // No transitions allowed from these states
|
return false // No transitions allowed from these states
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -195,40 +193,36 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
p.waitStarting.Add(1)
|
p.waitStarting.Add(1)
|
||||||
defer p.waitStarting.Done()
|
defer p.waitStarting.Done()
|
||||||
|
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
|
||||||
|
|
||||||
p.cmd = exec.Command(args[0], args[1:]...)
|
p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
|
||||||
p.cmd.Stdout = p.processLogger
|
p.cmd.Stdout = p.processLogger
|
||||||
p.cmd.Stderr = p.processLogger
|
p.cmd.Stderr = p.processLogger
|
||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
|
||||||
|
p.cmd.Cancel = p.cmdStopUpstreamProcess
|
||||||
|
p.cmd.WaitDelay = p.gracefulStopTimeout
|
||||||
|
p.cancelUpstream = ctxCancelUpstream
|
||||||
|
p.cmdWaitChan = make(chan struct{})
|
||||||
|
|
||||||
|
p.failedStartCount++ // this will be reset to zero when the process has successfully started
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
|
||||||
// Set process state to failed
|
// Set process state to failed
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil {
|
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
|
||||||
|
p.state = StateStopped // force it into a stopped state
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
|
||||||
err, curState, swapErr,
|
strings.Join(args, " "), err, curState, swapErr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("start() failed: %v", err)
|
return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture the exit error for later signalling
|
// Capture the exit error for later signalling
|
||||||
go func() {
|
go p.waitForCmd()
|
||||||
exitErr := p.cmd.Wait()
|
|
||||||
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
|
||||||
|
|
||||||
// there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then
|
|
||||||
// the code below fires, putting an error into cmdWaitChan. This code is to prevent this
|
|
||||||
if p.upstreamWasStoppedWithKill {
|
|
||||||
p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr)
|
|
||||||
p.upstreamWasStoppedWithKill = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p.cmdWaitChan <- exitErr
|
|
||||||
}()
|
|
||||||
|
|
||||||
// One of three things can happen at this stage:
|
// One of three things can happen at this stage:
|
||||||
// 1. The command exits unexpectedly
|
// 1. The command exits unexpectedly
|
||||||
@@ -244,67 +238,38 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
// a "none" means don't check for health ... I could have picked a better word :facepalm:
|
||||||
if checkEndpoint != "none" {
|
if checkEndpoint != "none" {
|
||||||
// keep default behaviour
|
|
||||||
if checkEndpoint == "" {
|
|
||||||
checkEndpoint = "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkDeadline, cancelHealthCheck := context.WithDeadline(
|
|
||||||
context.Background(),
|
|
||||||
checkStartTime.Add(maxDuration),
|
|
||||||
)
|
|
||||||
defer cancelHealthCheck()
|
|
||||||
|
|
||||||
loop:
|
|
||||||
// Ready Check loop
|
// Ready Check loop
|
||||||
for {
|
for {
|
||||||
select {
|
currentState := p.CurrentState()
|
||||||
case <-checkDeadline.Done():
|
if currentState != StateStarting {
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
if currentState == StateStopped {
|
||||||
return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState)
|
return fmt.Errorf("upstream command exited prematurely but successfully")
|
||||||
} else {
|
|
||||||
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
|
||||||
}
|
}
|
||||||
case <-p.shutdownCtx.Done():
|
|
||||||
return errors.New("health check interrupted due to shutdown")
|
return errors.New("health check interrupted due to shutdown")
|
||||||
case exitErr := <-p.cmdWaitChan:
|
|
||||||
if exitErr != nil {
|
|
||||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr)
|
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID)
|
|
||||||
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
|
|
||||||
return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("upstream command exited prematurely but successfully")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
|
||||||
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
|
||||||
cancelHealthCheck()
|
|
||||||
break loop
|
|
||||||
} else {
|
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
|
||||||
endTime, _ := checkDeadline.Deadline()
|
|
||||||
ttl := time.Until(endTime)
|
|
||||||
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if time.Since(checkStartTime) > maxDuration {
|
||||||
|
p.stopCommand()
|
||||||
|
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.checkHealthEndpoint(healthURL); err == nil {
|
||||||
|
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
|
ttl := time.Until(checkStartTime.Add(maxDuration))
|
||||||
|
p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
<-time.After(p.healthCheckLoopInterval)
|
<-time.After(p.healthCheckLoopInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -335,6 +300,7 @@ func (p *Process) start() error {
|
|||||||
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
} else {
|
} else {
|
||||||
|
p.failedStartCount = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -358,20 +324,13 @@ func (p *Process) StopImmediately() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process", p.ID)
|
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
||||||
|
|
||||||
// calling Stop() when state is invalid is a no-op
|
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop the process with a graceful exit timeout
|
p.stopCommand()
|
||||||
p.stopCommand(p.gracefulStopTimeout)
|
|
||||||
|
|
||||||
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
|
||||||
p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
// Shutdown is called when llama-swap is shutting down. It will give a little bit
|
||||||
@@ -379,64 +338,45 @@ func (p *Process) StopImmediately() {
|
|||||||
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
// is in the state of starting, it will cancel it and shut it down. Once a process is in
|
||||||
// the StateShutdown state, it can not be started again.
|
// the StateShutdown state, it can not be started again.
|
||||||
func (p *Process) Shutdown() {
|
func (p *Process) Shutdown() {
|
||||||
p.shutdownCancel()
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
p.stopCommand(p.gracefulStopTimeout)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.stopCommand()
|
||||||
|
// just force it to this state since there is no recovery from shutdown
|
||||||
p.state = StateShutdown
|
p.state = StateShutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
// stopCommand will send a SIGTERM to the process and wait for it to exit.
|
||||||
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
// If it does not exit within 5 seconds, it will send a SIGKILL.
|
||||||
func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
func (p *Process) stopCommand() {
|
||||||
stopStartTime := time.Now()
|
stopStartTime := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
|
if p.cancelUpstream == nil {
|
||||||
defer cancelTimeout()
|
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
|
||||||
|
|
||||||
if p.cmd == nil || p.cmd.Process == nil {
|
|
||||||
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
p.cancelUpstream()
|
||||||
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
<-p.cmdWaitChan
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-sigtermTimeout.Done():
|
|
||||||
p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
|
|
||||||
p.upstreamWasStoppedWithKill = true
|
|
||||||
if err := p.cmd.Process.Kill(); err != nil {
|
|
||||||
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
|
|
||||||
}
|
|
||||||
case err := <-p.cmdWaitChan:
|
|
||||||
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
|
|
||||||
// because if we make it here then the cmd has been successfully running and made it
|
|
||||||
// through the health check. There is a possibility that the cmd crashed after the health check
|
|
||||||
// succeeded but that's not a case llama-swap is handling for now.
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
|
||||||
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
|
||||||
} else if exitError, ok := err.(*exec.ExitError); ok {
|
|
||||||
if strings.Contains(exitError.String(), "signal: terminated") {
|
|
||||||
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
|
||||||
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
|
||||||
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 500 * time.Millisecond,
|
// wait a short time for a tcp connection to be established
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
|
||||||
|
// give a long time to respond to the health check endpoint
|
||||||
|
// after the connection is established. See issue: 276
|
||||||
|
Timeout: 5000 * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
@@ -464,7 +404,7 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// prevent new requests from being made while stopping or irrecoverable
|
// prevent new requests from being made while stopping or irrecoverable
|
||||||
currentState := p.CurrentState()
|
currentState := p.CurrentState()
|
||||||
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
|
if currentState == StateShutdown || currentState == StateStopping {
|
||||||
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -519,6 +459,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Header().Add(k, v)
|
w.Header().Add(k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
// faster than io.Copy when streaming
|
// faster than io.Copy when streaming
|
||||||
@@ -546,3 +490,79 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v",
|
||||||
p.ID, r.RequestURI, startDuration, totalTime)
|
p.ID, r.RequestURI, startDuration, totalTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForCmd waits for the command to exit and handles exit conditions depending on current state
|
||||||
|
func (p *Process) waitForCmd() {
|
||||||
|
exitErr := p.cmd.Wait()
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
|
||||||
|
|
||||||
|
if exitErr != nil {
|
||||||
|
if errno, ok := exitErr.(syscall.Errno); ok {
|
||||||
|
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
|
||||||
|
} else if exitError, ok := exitErr.(*exec.ExitError); ok {
|
||||||
|
if strings.Contains(exitError.String(), "signal: terminated") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
|
||||||
|
} else if strings.Contains(exitError.String(), "signal: interrupt") {
|
||||||
|
p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
|
||||||
|
} else {
|
||||||
|
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if exitErr.Error() != "context canceled" /* this is normal */ {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, exitErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.CurrentState()
|
||||||
|
switch currentState {
|
||||||
|
case StateStopping:
|
||||||
|
if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
|
||||||
|
p.state = StateStopped
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
|
||||||
|
p.state = StateStopped // force it to be in this state
|
||||||
|
}
|
||||||
|
close(p.cmdWaitChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdStopUpstreamProcess attemps to stop the upstream process gracefully
|
||||||
|
func (p *Process) cmdStopUpstreamProcess() error {
|
||||||
|
p.processLogger.Debugf("<%s> cmdStopUpstreamProcess() initiating graceful stop of upstream process", p.ID)
|
||||||
|
|
||||||
|
// this should never happen ...
|
||||||
|
if p.cmd == nil || p.cmd.Process == nil {
|
||||||
|
p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
|
||||||
|
return fmt.Errorf("<%s> process is nil or cmd is nil, skipping graceful stop", p.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.CmdStop != "" {
|
||||||
|
// replace ${PID} with the pid of the process
|
||||||
|
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||||
|
if err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||||
|
|
||||||
|
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||||
|
stopCmd.Stdout = p.processLogger
|
||||||
|
stopCmd.Stderr = p.processLogger
|
||||||
|
stopCmd.Env = p.cmd.Env
|
||||||
|
|
||||||
|
if err := stopCmd.Run(); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
|
|||||||
// test that the automatic start returns the expected error type
|
// test that the automatic start returns the expected error type
|
||||||
func TestProcess_BrokenModelConfig(t *testing.T) {
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "nonexistent-command",
|
Cmd: "nonexistent-command",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -106,8 +107,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
|
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed")
|
assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
@@ -248,18 +249,14 @@ func TestProcess_SwapState(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
{"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting},
|
||||||
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
{"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady},
|
||||||
{"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed},
|
|
||||||
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
{"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping},
|
||||||
|
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, nil, StateStopped},
|
||||||
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
{"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping},
|
||||||
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
{"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped},
|
||||||
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
{"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown},
|
||||||
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
{"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped},
|
||||||
{"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting},
|
|
||||||
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
{"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady},
|
||||||
{"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady},
|
|
||||||
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
{"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping},
|
||||||
{"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed},
|
|
||||||
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
|
||||||
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
|
||||||
@@ -329,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
|
|
||||||
// should run and exit but interrupt the long checkHealthTimeout
|
// should run and exit but interrupt the long checkHealthTimeout
|
||||||
checkHealthTimeout := 5
|
checkHealthTimeout := 5
|
||||||
config := ModelConfig{
|
config := config.ModelConfig{
|
||||||
Cmd: "sleep 1",
|
Cmd: "sleep 1",
|
||||||
Proxy: "http://127.0.0.1:9913",
|
Proxy: "http://127.0.0.1:9913",
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
@@ -339,7 +336,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
process.healthCheckLoopInterval = time.Second // make it faster
|
process.healthCheckLoopInterval = time.Second // make it faster
|
||||||
err := process.start()
|
err := process.start()
|
||||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||||
@@ -398,12 +395,15 @@ func TestProcess_StopImmediately(t *testing.T) {
|
|||||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||||
// the upstream command
|
// the upstream command
|
||||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("skipping SIGTERM test on Windows ")
|
||||||
|
}
|
||||||
|
|
||||||
expectedMessage := "test_sigkill"
|
expectedMessage := "test_sigkill"
|
||||||
binaryPath := getSimpleResponderPath()
|
binaryPath := getSimpleResponderPath()
|
||||||
port := getTestPort()
|
port := getTestPort()
|
||||||
|
|
||||||
config := ModelConfig{
|
conf := config.ModelConfig{
|
||||||
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
|
||||||
// to force the process to exit
|
// to force the process to exit
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
|
||||||
@@ -411,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
|
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
|
||||||
defer process.Stop()
|
defer process.Stop()
|
||||||
|
|
||||||
// reduce to make testing go faster
|
// reduce to make testing go faster
|
||||||
@@ -449,3 +449,46 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
// the request should have been interrupted by SIGKILL
|
// the request should have been interrupted by SIGKILL
|
||||||
<-waitChan
|
<-waitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopCmd(t *testing.T) {
|
||||||
|
conf := getTestSimpleResponderConfig("test_stop_cmd")
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
conf.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
} else {
|
||||||
|
conf.CmdStop = "kill -TERM ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
|
||||||
|
expectedMessage := "test_env_not_emptied"
|
||||||
|
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// ensure that the the default config does not blank out the inherited environment
|
||||||
|
configWEnv := conf
|
||||||
|
|
||||||
|
// ensure the additiona variables are appended to the process' environment
|
||||||
|
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
|
||||||
|
|
||||||
|
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
|
||||||
|
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
|
||||||
|
|
||||||
|
process1.start()
|
||||||
|
defer process1.Stop()
|
||||||
|
process2.start()
|
||||||
|
defer process2.Stop()
|
||||||
|
|
||||||
|
assert.NotZero(t, len(process1.cmd.Environ()))
|
||||||
|
assert.NotZero(t, len(process2.cmd.Environ()))
|
||||||
|
assert.Equal(t, len(process1.cmd.Environ())+2, len(process2.cmd.Environ()), "process2 should have 2 more environment variables than process1")
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProcessGroup struct {
|
type ProcessGroup struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
id string
|
id string
|
||||||
swap bool
|
swap bool
|
||||||
exclusive bool
|
exclusive bool
|
||||||
@@ -24,7 +26,7 @@ type ProcessGroup struct {
|
|||||||
lastUsedProcess string
|
lastUsedProcess string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
groupConfig, ok := config.Groups[id]
|
groupConfig, ok := config.Groups[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Unable to find configuration for group id: " + id)
|
panic("Unable to find configuration for group id: " + id)
|
||||||
@@ -60,10 +62,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
if pg.swap {
|
if pg.swap {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
if pg.lastUsedProcess != modelID {
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// is there something already running?
|
||||||
if pg.lastUsedProcess != "" {
|
if pg.lastUsedProcess != "" {
|
||||||
pg.processes[pg.lastUsedProcess].Stop()
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for the request to the new model to be fully handled
|
||||||
|
// and prevent race conditions see issue #277
|
||||||
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
pg.lastUsedProcess = modelID
|
pg.lastUsedProcess = modelID
|
||||||
|
|
||||||
|
// short circuit and exit
|
||||||
|
pg.Unlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
}
|
}
|
||||||
@@ -76,6 +88,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
|
pg.Lock()
|
||||||
|
|
||||||
|
process, exists := pg.processes[modelID]
|
||||||
|
if !exists {
|
||||||
|
pg.Unlock()
|
||||||
|
return fmt.Errorf("process not found for %s", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pg.lastUsedProcess == modelID {
|
||||||
|
pg.lastUsedProcess = ""
|
||||||
|
}
|
||||||
|
pg.Unlock()
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
|
|||||||
@@ -4,21 +4,23 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
"model4": getTestSimpleResponderConfig("model4"),
|
"model4": getTestSimpleResponderConfig("model4"),
|
||||||
"model5": getTestSimpleResponderConfig("model5"),
|
"model5": getTestSimpleResponderConfig("model5"),
|
||||||
},
|
},
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Exclusive: true,
|
Exclusive: true,
|
||||||
@@ -33,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|||||||
})
|
})
|
||||||
|
|
||||||
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
||||||
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
||||||
assert.True(t, pg.HasMember("model5"))
|
assert.True(t, pg.HasMember("model5"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,32 +46,49 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
|||||||
assert.False(t, pg.HasMember("model3"))
|
assert.False(t, pg.HasMember("model3"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||||
|
// and multiple requests are made in parallel, only one process is running at a time.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||||
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
// use the same listening so if a model is already running, it will fail
|
||||||
|
// this is a way to test that swap isolation is working
|
||||||
|
// properly when there are parallel requests made at the
|
||||||
|
// same time.
|
||||||
|
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
||||||
|
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
||||||
|
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
||||||
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
||||||
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
tests := []string{"model1", "model2"}
|
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(len(tests))
|
||||||
for _, modelName := range tests {
|
for _, modelName := range tests {
|
||||||
t.Run(modelName, func(t *testing.T) {
|
go func(modelName string) {
|
||||||
reqBody := `{"x", "y"}`
|
defer wg.Done()
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
}(modelName)
|
||||||
// make sure only one process is in the running state
|
|
||||||
count := 0
|
|
||||||
for _, process := range pg.processes {
|
|
||||||
if process.CurrentState() == StateReady {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.Equal(t, 1, count)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -26,7 +28,7 @@ const (
|
|||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config Config
|
config config.Config
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
|
|
||||||
// logging
|
// logging
|
||||||
@@ -34,10 +36,16 @@ type ProxyManager struct {
|
|||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
|
metricsMonitor *MetricsMonitor
|
||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
|
// shutdown signaling
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) *ProxyManager {
|
func New(config config.Config) *ProxyManager {
|
||||||
// set up loggers
|
// set up loggers
|
||||||
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
||||||
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
||||||
@@ -65,6 +73,8 @@ func New(config Config) *ProxyManager {
|
|||||||
upstreamLogger.SetLogLevel(LevelInfo)
|
upstreamLogger.SetLogLevel(LevelInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
@@ -73,7 +83,12 @@ func New(config Config) *ProxyManager {
|
|||||||
muxLogger: stdoutLogger,
|
muxLogger: stdoutLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
|
metricsMonitor: NewMetricsMonitor(&config),
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the process groups
|
// create the process groups
|
||||||
@@ -83,6 +98,35 @@ func New(config Config) *ProxyManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pm.setupGinEngine()
|
pm.setupGinEngine()
|
||||||
|
|
||||||
|
// run any startup hooks
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
// do it in the background, don't block startup -- not sure if good idea yet
|
||||||
|
go func() {
|
||||||
|
discardWriter := &DiscardWriter{}
|
||||||
|
for _, realModelName := range config.Hooks.OnStartup.Preload {
|
||||||
|
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||||
|
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: realModelName,
|
||||||
|
Success: false,
|
||||||
|
})
|
||||||
|
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
req, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||||
|
event.Emit(ModelPreloadedEvent{
|
||||||
|
ModelName: realModelName,
|
||||||
|
Success: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return pm
|
return pm
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,14 +185,27 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
c.Next()
|
c.Next()
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
mm := MetricsMiddleware(pm)
|
||||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
|
||||||
// Support legacy /v1/completions api, see issue #12
|
|
||||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
|
||||||
|
|
||||||
// Support embeddings
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
// Support legacy /v1/completions api, see issue #12
|
||||||
|
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// Support embeddings and reranking
|
||||||
|
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /reranking endpoint + aliases
|
||||||
|
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /infill endpoint for code infilling
|
||||||
|
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
|
// llama-server's /completion endpoint
|
||||||
|
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
@@ -159,42 +216,63 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
// in proxymanager_loghandlers.go
|
// in proxymanager_loghandlers.go
|
||||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
|
||||||
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
|
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
|
||||||
|
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* User Interface Endpoints
|
||||||
|
*/
|
||||||
pm.ginEngine.GET("/", func(c *gin.Context) {
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
||||||
// Set the Content-Type header to text/html
|
c.Redirect(http.StatusFound, "/ui")
|
||||||
c.Header("Content-Type", "text/html")
|
})
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
htmlData, err := getHTMLFile("index.html")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
if err != nil {
|
})
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||||
return
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
}
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||||
_, err = c.Writer.Write(htmlData)
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
if err != nil {
|
c.String(http.StatusOK, "OK")
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
||||||
if data, err := getHTMLFile("favicon.ico"); err == nil {
|
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
||||||
c.Data(http.StatusOK, "image/x-icon", data)
|
c.Data(http.StatusOK, "image/x-icon", data)
|
||||||
} else {
|
} else {
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
reactFS, err := GetReactFS()
|
||||||
|
if err != nil {
|
||||||
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// serve files that exist under /ui/*
|
||||||
|
pm.ginEngine.StaticFS("/ui", reactFS)
|
||||||
|
|
||||||
|
// server SPA for UI under /ui/*
|
||||||
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
||||||
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := reactFS.Open("index.html")
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: proxymanager_api.go
|
||||||
|
// add API handler functions
|
||||||
|
addApiHandlers(pm)
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
}
|
}
|
||||||
@@ -242,6 +320,7 @@ func (pm *ProxyManager) Shutdown() {
|
|||||||
}(processGroup)
|
}(processGroup)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
pm.shutdownCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||||
@@ -269,76 +348,121 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := make([]gin.H, 0, len(pm.config.Models))
|
||||||
|
createdTime := time.Now().Unix()
|
||||||
|
|
||||||
for id, modelConfig := range pm.config.Models {
|
for id, modelConfig := range pm.config.Models {
|
||||||
if modelConfig.Unlisted {
|
if modelConfig.Unlisted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, map[string]interface{}{
|
record := gin.H{
|
||||||
"id": id,
|
"id": id,
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": time.Now().Unix(),
|
"created": createdTime,
|
||||||
"owned_by": "llama-swap",
|
"owned_by": "llama-swap",
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||||
|
record["name"] = name
|
||||||
|
}
|
||||||
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||||
|
record["description"] = desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata if present
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
record["meta"] = gin.H{
|
||||||
|
"llamaswap": modelConfig.Metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data = append(data, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the Content-Type header to application/json
|
// Sort by the "id" key
|
||||||
c.Header("Content-Type", "application/json")
|
sort.Slice(data, func(i, j int) bool {
|
||||||
|
si, _ := data[i]["id"].(string)
|
||||||
|
sj, _ := data[j]["id"].(string)
|
||||||
|
return si < sj
|
||||||
|
})
|
||||||
|
|
||||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
// Set CORS headers if origin exists
|
||||||
|
if origin := c.GetHeader("Origin"); origin != "" {
|
||||||
c.Header("Access-Control-Allow-Origin", origin)
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Use gin's JSON method which handles content-type and encoding
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
"object": "list",
|
||||||
return
|
"data": data,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
requestedModel := c.Param("model_id")
|
upstreamPath := c.Param("upstreamPath")
|
||||||
|
|
||||||
if requestedModel == "" {
|
// split the upstream path by / and search for the model name
|
||||||
|
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||||
|
if len(parts) == 0 {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
modelFound := false
|
||||||
|
searchModelName := ""
|
||||||
|
var modelName, remainingPath string
|
||||||
|
for i, part := range parts {
|
||||||
|
if parts[i] == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if searchModelName == "" {
|
||||||
|
searchModelName = part
|
||||||
|
} else {
|
||||||
|
searchModelName = searchModelName + "/" + parts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||||
|
modelName = real
|
||||||
|
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||||
|
modelFound = true
|
||||||
|
|
||||||
|
// Check if this is exactly a model name with no additional path
|
||||||
|
// and doesn't end with a trailing slash
|
||||||
|
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||||
|
// Build new URL with query parameters preserved
|
||||||
|
newPath := "/upstream/" + searchModelName + "/"
|
||||||
|
if c.Request.URL.RawQuery != "" {
|
||||||
|
newPath += "?" + c.Request.URL.RawQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use 308 for non-GET/HEAD requests to preserve method
|
||||||
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||||
|
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||||
|
} else {
|
||||||
|
c.Redirect(http.StatusPermanentRedirect, newPath)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modelFound {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
c.Request.URL.Path = remainingPath
|
||||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
|
||||||
var html strings.Builder
|
|
||||||
|
|
||||||
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
|
|
||||||
|
|
||||||
// Extract keys and sort them
|
|
||||||
var modelIDs []string
|
|
||||||
for modelID, modelConfig := range pm.config.Models {
|
|
||||||
if modelConfig.Unlisted {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
modelIDs = append(modelIDs, modelID)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIDs)
|
|
||||||
|
|
||||||
// Iterate over sorted keys
|
|
||||||
for _, modelID := range modelIDs {
|
|
||||||
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
|
|
||||||
}
|
|
||||||
html.WriteString("</ul></body></html>")
|
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
c.String(http.StatusOK, html.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
@@ -354,7 +478,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
@@ -370,6 +500,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// issue #174 strip parameters from the JSON body
|
||||||
|
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||||
|
if err != nil { // just log it and continue
|
||||||
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||||
|
} else {
|
||||||
|
for _, param := range stripParams {
|
||||||
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||||
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
|||||||
@@ -0,0 +1,229 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Unlisted bool `json:"unlisted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func addApiHandlers(pm *ProxyManager) {
|
||||||
|
// Add API endpoints for React to consume
|
||||||
|
apiGroup := pm.ginEngine.Group("/api")
|
||||||
|
{
|
||||||
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
|
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||||
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
|
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadAllModels(c *gin.Context) {
|
||||||
|
pm.StopProcesses(StopImmediately)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"msg": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) getModelStatus() []Model {
|
||||||
|
// Extract keys and sort them
|
||||||
|
models := []Model{}
|
||||||
|
|
||||||
|
modelIDs := make([]string, 0, len(pm.config.Models))
|
||||||
|
for modelID := range pm.config.Models {
|
||||||
|
modelIDs = append(modelIDs, modelID)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIDs)
|
||||||
|
|
||||||
|
// Iterate over sorted keys
|
||||||
|
for _, modelID := range modelIDs {
|
||||||
|
// Get process state
|
||||||
|
processGroup := pm.findGroupByModelName(modelID)
|
||||||
|
state := "unknown"
|
||||||
|
if processGroup != nil {
|
||||||
|
process := processGroup.processes[modelID]
|
||||||
|
if process != nil {
|
||||||
|
var stateStr string
|
||||||
|
switch process.CurrentState() {
|
||||||
|
case StateReady:
|
||||||
|
stateStr = "ready"
|
||||||
|
case StateStarting:
|
||||||
|
stateStr = "starting"
|
||||||
|
case StateStopping:
|
||||||
|
stateStr = "stopping"
|
||||||
|
case StateShutdown:
|
||||||
|
stateStr = "shutdown"
|
||||||
|
case StateStopped:
|
||||||
|
stateStr = "stopped"
|
||||||
|
default:
|
||||||
|
stateStr = "unknown"
|
||||||
|
}
|
||||||
|
state = stateStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models = append(models, Model{
|
||||||
|
Id: modelID,
|
||||||
|
Name: pm.config.Models[modelID].Name,
|
||||||
|
Description: pm.config.Models[modelID].Description,
|
||||||
|
State: state,
|
||||||
|
Unlisted: pm.config.Models[modelID].Unlisted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
msgTypeModelStatus messageType = "modelStatus"
|
||||||
|
msgTypeLogData messageType = "logData"
|
||||||
|
msgTypeMetrics messageType = "metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageEnvelope struct {
|
||||||
|
Type messageType `json:"type"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// sends a stream of different message types that happen on the server
|
||||||
|
func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering SSE
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
sendBuffer := make(chan messageEnvelope, 25)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
sendModels := func() {
|
||||||
|
data, err := json.Marshal(pm.getModelStatus())
|
||||||
|
if err == nil {
|
||||||
|
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
|
||||||
|
select {
|
||||||
|
case sendBuffer <- msg:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendLogData := func(source string, data []byte) {
|
||||||
|
data, err := json.Marshal(gin.H{
|
||||||
|
"source": source,
|
||||||
|
"data": string(data),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMetrics := func(metrics []TokenMetrics) {
|
||||||
|
jsonData, err := json.Marshal(metrics)
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send updated models list
|
||||||
|
*/
|
||||||
|
defer event.On(func(e ProcessStateChangeEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
defer event.On(func(e ConfigFileChangedEvent) {
|
||||||
|
sendModels()
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Log data
|
||||||
|
*/
|
||||||
|
defer pm.proxyLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("proxy", data)
|
||||||
|
})()
|
||||||
|
defer pm.upstreamLogger.OnLogData(func(data []byte) {
|
||||||
|
sendLogData("upstream", data)
|
||||||
|
})()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send Metrics data
|
||||||
|
*/
|
||||||
|
defer event.On(func(e TokenMetricsEvent) {
|
||||||
|
sendMetrics([]TokenMetrics{e.Metrics})
|
||||||
|
})()
|
||||||
|
|
||||||
|
// send initial batch of data
|
||||||
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
|
sendModels()
|
||||||
|
sendMetrics(pm.metricsMonitor.GetMetrics())
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case <-pm.shutdownCtx.Done():
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case msg := <-sendBuffer:
|
||||||
|
c.SSEvent("message", msg)
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||||
|
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||||
|
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,20 +12,7 @@ import (
|
|||||||
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) {
|
||||||
accept := c.GetHeader("Accept")
|
accept := c.GetHeader("Accept")
|
||||||
if strings.Contains(accept, "text/html") {
|
if strings.Contains(accept, "text/html") {
|
||||||
// Set the Content-Type header to text/html
|
c.Redirect(http.StatusFound, "/ui/")
|
||||||
c.Header("Content-Type", "text/html")
|
|
||||||
|
|
||||||
// Write the embedded HTML content to the response
|
|
||||||
logsHTML, err := getHTMLFile("logs.html")
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = c.Writer.Write(logsHTML)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
history := pm.muxLogger.GetHistory()
|
history := pm.muxLogger.GetHistory()
|
||||||
@@ -40,6 +28,8 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
|
// prevent nginx from buffering streamed logs
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
logMonitorId := c.Param("logMonitorID")
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
logger, err := pm.getLogger(logMonitorId)
|
||||||
@@ -47,10 +37,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
c.String(http.StatusBadRequest, err.Error())
|
c.String(http.StatusBadRequest, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
|
||||||
@@ -68,57 +55,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream new logs
|
sendChan := make(chan []byte, 10)
|
||||||
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
defer logger.OnLogData(func(data []byte) {
|
||||||
|
select {
|
||||||
|
case sendChan <- data:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case <-c.Request.Context().Done():
|
||||||
_, err := c.Writer.Write(msg)
|
cancel()
|
||||||
if err != nil {
|
return
|
||||||
// just break the loop if we can't write for some reason
|
case <-pm.shutdownCtx.Done():
|
||||||
return
|
cancel()
|
||||||
}
|
return
|
||||||
|
case data := <-sendChan:
|
||||||
|
c.Writer.Write(data)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
|
||||||
|
|
||||||
logMonitorId := c.Param("logMonitorID")
|
|
||||||
logger, err := pm.getLogger(logMonitorId)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ch := logger.Subscribe()
|
|
||||||
defer logger.Unsubscribe(ch)
|
|
||||||
|
|
||||||
notify := c.Request.Context().Done()
|
|
||||||
|
|
||||||
// Send history first if not skipped
|
|
||||||
_, skipHistory := c.GetQuery("no-history")
|
|
||||||
if !skipHistory {
|
|
||||||
history := logger.GetHistory()
|
|
||||||
if len(history) != 0 {
|
|
||||||
c.SSEvent("message", string(history))
|
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream new logs
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
c.SSEvent("message", string(msg))
|
|
||||||
c.Writer.Flush()
|
|
||||||
case <-notify:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@@ -9,18 +10,21 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
@@ -40,16 +44,15 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
assert.Contains(t, w.Body.String(), modelName)
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"G1": {
|
"G1": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
Exclusive: false,
|
Exclusive: false,
|
||||||
@@ -87,14 +90,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
// Test that a persistent group is not affected by the swapping behaviour of
|
// Test that a persistent group is not affected by the swapping behaviour of
|
||||||
// other groups.
|
// other groups.
|
||||||
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
// the forever group is persistent and should not be affected by model1
|
// the forever group is persistent and should not be affected by model1
|
||||||
"forever": {
|
"forever": {
|
||||||
Swap: true,
|
Swap: true,
|
||||||
@@ -131,9 +134,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
t.Skip("skipping slow test")
|
t.Skip("skipping slow test")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
@@ -165,9 +168,11 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
var response map[string]string
|
var response map[string]interface{}
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
results[key] = response["responseMessage"]
|
result, ok := response["responseMessage"].(string)
|
||||||
|
assert.Equal(t, ok, true)
|
||||||
|
results[key] = result
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}(key)
|
}(key)
|
||||||
|
|
||||||
@@ -183,11 +188,20 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||||
config := Config{
|
|
||||||
|
model1Config := getTestSimpleResponderConfig("model1")
|
||||||
|
model1Config.Name = "Model 1"
|
||||||
|
model1Config.Description = "Model 1 description is used for testing"
|
||||||
|
|
||||||
|
model2Config := getTestSimpleResponderConfig("model2")
|
||||||
|
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||||
|
model2Config.Description = " "
|
||||||
|
|
||||||
|
config := config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": model1Config,
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": model2Config,
|
||||||
"model3": getTestSimpleResponderConfig("model3"),
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -213,6 +227,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
var response struct {
|
var response struct {
|
||||||
Data []map[string]interface{} `json:"data"`
|
Data []map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -227,6 +242,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
"model3": {},
|
"model3": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make all models
|
||||||
for _, model := range response.Data {
|
for _, model := range response.Data {
|
||||||
modelID, ok := model["id"].(string)
|
modelID, ok := model["id"].(string)
|
||||||
assert.True(t, ok, "model ID should be a string")
|
assert.True(t, ok, "model ID should be a string")
|
||||||
@@ -245,12 +261,156 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
ownedBy, ok := model["owned_by"].(string)
|
ownedBy, ok := model["owned_by"].(string)
|
||||||
assert.True(t, ok, "owned_by should be a string")
|
assert.True(t, ok, "owned_by should be a string")
|
||||||
assert.Equal(t, "llama-swap", ownedBy)
|
assert.Equal(t, "llama-swap", ownedBy)
|
||||||
|
|
||||||
|
// check for optional name and description
|
||||||
|
if modelID == "model1" {
|
||||||
|
name, ok := model["name"].(string)
|
||||||
|
assert.True(t, ok, "name should be a string")
|
||||||
|
assert.Equal(t, "Model 1", name)
|
||||||
|
description, ok := model["description"].(string)
|
||||||
|
assert.True(t, ok, "description should be a string")
|
||||||
|
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||||
|
} else {
|
||||||
|
_, exists := model["name"]
|
||||||
|
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||||
|
_, exists = model["description"]
|
||||||
|
assert.False(t, exists, "unexpected description field for model: %s", modelID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure all expected models were returned
|
// Ensure all expected models were returned
|
||||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ListModelsHandler_WithMetadata(t *testing.T) {
|
||||||
|
// Process config through LoadConfigFromReader to apply macro substitution
|
||||||
|
configYaml := `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
startPort: 10000
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
macros:
|
||||||
|
PORT_NUM: 10001
|
||||||
|
TEMP: 0.7
|
||||||
|
NAME: "llama"
|
||||||
|
metadata:
|
||||||
|
port: ${PORT_NUM}
|
||||||
|
temperature: ${TEMP}
|
||||||
|
enabled: true
|
||||||
|
note: "Running on port ${PORT_NUM}"
|
||||||
|
nested:
|
||||||
|
value: ${TEMP}
|
||||||
|
model2:
|
||||||
|
cmd: /path/to/server -p ${PORT}
|
||||||
|
`
|
||||||
|
processedConfig, err := config.LoadConfigFromReader(strings.NewReader(configYaml))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
proxy := New(processedConfig)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var response struct {
|
||||||
|
Data []map[string]any `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, response.Data, 2)
|
||||||
|
|
||||||
|
// Find model1 and model2 in response
|
||||||
|
var model1Data, model2Data map[string]any
|
||||||
|
for _, model := range response.Data {
|
||||||
|
if model["id"] == "model1" {
|
||||||
|
model1Data = model
|
||||||
|
} else if model["id"] == "model2" {
|
||||||
|
model2Data = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify model1 has llamaswap_meta
|
||||||
|
assert.NotNil(t, model1Data)
|
||||||
|
meta, exists := model1Data["meta"]
|
||||||
|
if !assert.True(t, exists, "model1 should have meta key") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
metaMap := meta.(map[string]any)
|
||||||
|
|
||||||
|
lsmeta, exists := metaMap["llamaswap"]
|
||||||
|
if !assert.True(t, exists, "model1 should have meta.llamaswap key") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
lsmetamap := lsmeta.(map[string]any)
|
||||||
|
|
||||||
|
// Verify type preservation
|
||||||
|
assert.Equal(t, float64(10001), lsmetamap["port"]) // JSON numbers are float64
|
||||||
|
assert.Equal(t, 0.7, lsmetamap["temperature"])
|
||||||
|
assert.Equal(t, true, lsmetamap["enabled"])
|
||||||
|
// Verify string interpolation
|
||||||
|
assert.Equal(t, "Running on port 10001", lsmetamap["note"])
|
||||||
|
// Verify nested structure
|
||||||
|
nested := lsmetamap["nested"].(map[string]any)
|
||||||
|
assert.Equal(t, 0.7, nested["value"])
|
||||||
|
|
||||||
|
// Verify model2 does NOT have llamaswap_meta
|
||||||
|
assert.NotNil(t, model2Data)
|
||||||
|
_, exists = model2Data["llamaswap_meta"]
|
||||||
|
assert.False(t, exists, "model2 should not have llamaswap_meta")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||||
|
// Intentionally add models in non-sorted order and with an unlisted model
|
||||||
|
config := config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||||
|
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||||
|
"beta": getTestSimpleResponderConfig("beta"),
|
||||||
|
"hidden": func() config.ModelConfig {
|
||||||
|
mc := getTestSimpleResponderConfig("hidden")
|
||||||
|
mc.Unlisted = true
|
||||||
|
return mc
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
|
||||||
|
// Request models list
|
||||||
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var response struct {
|
||||||
|
Data []map[string]interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We expect only the listed models in sorted order by id
|
||||||
|
expectedOrder := []string{"alpha", "beta", "zeta"}
|
||||||
|
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
||||||
|
got := make([]string, 0, len(response.Data))
|
||||||
|
for _, m := range response.Data {
|
||||||
|
id, _ := m["id"].(string)
|
||||||
|
got = append(got, id)
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyManager_Shutdown(t *testing.T) {
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
// make broken model configurations
|
// make broken model configurations
|
||||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||||
@@ -262,15 +422,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
|
||||||
model3Config.Proxy = "http://localhost:10003/"
|
model3Config.Proxy = "http://localhost:10003/"
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": model1Config,
|
"model1": model1Config,
|
||||||
"model2": model2Config,
|
"model2": model2Config,
|
||||||
"model3": model3Config,
|
"model3": model3Config,
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
Groups: map[string]GroupConfig{
|
Groups: map[string]config.GroupConfig{
|
||||||
"test": {
|
"test": {
|
||||||
Swap: false,
|
Swap: false,
|
||||||
Members: []string{"model1", "model2", "model3"},
|
Members: []string{"model1", "model2", "model3"},
|
||||||
@@ -305,38 +465,92 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Unload(t *testing.T) {
|
func TestProxyManager_Unload(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(conf)
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
req = httptest.NewRequest("GET", "/unload", nil)
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
|
||||||
// give it a bit of time to stop
|
select {
|
||||||
<-time.After(time.Millisecond * 250)
|
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
// good
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for model1 to stop")
|
||||||
|
}
|
||||||
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||||
|
const testGroupId = "testGroup"
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
testGroupId: {
|
||||||
|
Swap: false,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// start both model
|
||||||
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
||||||
|
// good
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for model1 to stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
||||||
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
// Shared configuration
|
// Shared configuration
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
"model2": getTestSimpleResponderConfig("model2"),
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
},
|
},
|
||||||
@@ -399,9 +613,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -452,15 +666,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
|
||||||
modelConfig.UseModelName = upstreamModelName
|
modelConfig.UseModelName = upstreamModelName
|
||||||
|
|
||||||
config := AddDefaultGroupToConfig(Config{
|
conf := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": modelConfig,
|
"model1": modelConfig,
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
})
|
})
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(conf)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
requestedModel := "model1"
|
requestedModel := "model1"
|
||||||
@@ -515,9 +729,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -583,27 +797,40 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Upstream(t *testing.T) {
|
func TestProxyManager_Upstream(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
configStr := fmt.Sprintf(`
|
||||||
HealthCheckTimeout: 15,
|
logLevel: error
|
||||||
Models: map[string]ModelConfig{
|
models:
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
model1:
|
||||||
},
|
cmd: %s -port ${PORT} -silent -respond model1
|
||||||
LogLevel: "error",
|
aliases: [model-alias]
|
||||||
})
|
`, getSimpleResponderPath())
|
||||||
|
|
||||||
|
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
t.Run("main model name", func(t *testing.T) {
|
||||||
rec := httptest.NewRecorder()
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
proxy.ServeHTTP(rec, req)
|
rec := httptest.NewRecorder()
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model alias", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]config.ModelConfig{
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
@@ -618,8 +845,275 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
var response map[string]string
|
var response map[string]interface{}
|
||||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
assert.Equal(t, "81", response["h_content_length"])
|
assert.Equal(t, "81", response["h_content_length"])
|
||||||
assert.Equal(t, "model1", response["responseMessage"])
|
assert.Equal(t, "model1", response["responseMessage"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||||
|
modelConfig := getTestSimpleResponderConfig("model1")
|
||||||
|
modelConfig.Filters = config.ModelFilters{
|
||||||
|
StripParams: "temperature, model, stream",
|
||||||
|
}
|
||||||
|
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
LogLevel: "error",
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": modelConfig,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
var response map[string]interface{}
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
|
||||||
|
// `temperature` and `stream` are gone but model remains
|
||||||
|
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
|
||||||
|
|
||||||
|
// assert.Nil(t, response["temperature"])
|
||||||
|
// assert.Equal(t, "123", response["x_param"])
|
||||||
|
// assert.Equal(t, "abc", response["y_param"])
|
||||||
|
// t.Logf("%v", response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
// Make a non-streaming request
|
||||||
|
reqBody := `{"model":"model1", "stream": false}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// Check that metrics were recorded
|
||||||
|
metrics := proxy.metricsMonitor.GetMetrics()
|
||||||
|
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the last metric has the correct model
|
||||||
|
lastMetric := metrics[len(metrics)-1]
|
||||||
|
assert.Equal(t, "model1", lastMetric.Model)
|
||||||
|
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||||
|
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||||
|
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||||
|
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
// Make a streaming request
|
||||||
|
reqBody := `{"model":"model1", "stream": true}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// Check that metrics were recorded
|
||||||
|
metrics := proxy.metricsMonitor.GetMetrics()
|
||||||
|
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the last metric has the correct model
|
||||||
|
lastMetric := metrics[len(metrics)-1]
|
||||||
|
assert.Equal(t, "model1", lastMetric.Model)
|
||||||
|
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||||
|
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||||
|
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||||
|
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "OK", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the custom llama-server /completion endpoint proxies correctly
|
||||||
|
func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
reqBody := `{"model":"model1"}`
|
||||||
|
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_StartupHooks(t *testing.T) {
|
||||||
|
|
||||||
|
// using real YAML as the configuration has gotten more complex
|
||||||
|
// is the right approach as LoadConfigFromReader() does a lot more
|
||||||
|
// than parse YAML now. Eventually migrate all tests to use this approach
|
||||||
|
configStr := strings.Replace(`
|
||||||
|
logLevel: error
|
||||||
|
hooks:
|
||||||
|
on_startup:
|
||||||
|
preload:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
groups:
|
||||||
|
preloadTestGroup:
|
||||||
|
swap: false
|
||||||
|
members:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1
|
||||||
|
model2:
|
||||||
|
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2
|
||||||
|
`, "${simpleresponderpath}", simpleResponderPath, -1)
|
||||||
|
|
||||||
|
// Create a test model configuration
|
||||||
|
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||||
|
if !assert.NoError(t, err, "Invalid configuration") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events
|
||||||
|
|
||||||
|
unsub := event.On(func(e ModelPreloadedEvent) {
|
||||||
|
preloadChan <- e
|
||||||
|
})
|
||||||
|
|
||||||
|
defer unsub()
|
||||||
|
|
||||||
|
// Create the proxy which should trigger preloading
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
select {
|
||||||
|
case <-preloadChan:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for models to preload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// make sure they are both loaded
|
||||||
|
_, foundGroup := proxy.processGroups["preloadTestGroup"]
|
||||||
|
if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
endpoints := []string{
|
||||||
|
"/api/events",
|
||||||
|
"/logs/stream",
|
||||||
|
"/logs/stream/proxy",
|
||||||
|
"/logs/stream/upstream",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, endpoint := range endpoints {
|
||||||
|
t.Run(endpoint, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", endpoint, nil)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// We don't need the handler to fully complete, just to set the headers
|
||||||
|
// so run it in a goroutine and check the headers after a short delay
|
||||||
|
go proxy.ServeHTTP(rec, req)
|
||||||
|
time.Sleep(10 * time.Millisecond) // give it time to start and write headers
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
|
||||||
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|
||||||
|
// Make a streaming request
|
||||||
|
reqBody := `{"model":"streaming-model"}`
|
||||||
|
// simple-responder will return text/event-stream when stream=true is in the query
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
|
||||||
|
assert.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed ui_dist
|
||||||
|
var reactStaticFS embed.FS
|
||||||
|
|
||||||
|
// GetReactFS returns the embedded React filesystem
|
||||||
|
func GetReactFS() (http.FileSystem, error) {
|
||||||
|
subFS, err := fs.Sub(reactStaticFS, "ui_dist")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return http.FS(subFS), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReactIndexHTML returns the main index.html for the React app
|
||||||
|
func GetReactIndexHTML() ([]byte, error) {
|
||||||
|
return reactStaticFS.ReadFile("ui_dist/index.html")
|
||||||
|
}
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script installs llama-swap on Linux.
|
||||||
|
# It detects the current operating system architecture and installs the appropriate version of llama-swap.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"}
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v "$1" >/dev/null; }
|
||||||
|
require() {
|
||||||
|
_MISSING=''
|
||||||
|
for TOOL in "$@"; do
|
||||||
|
if ! available "$TOOL"; then
|
||||||
|
_MISSING="$_MISSING $TOOL"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "$_MISSING"
|
||||||
|
}
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
NEEDS=$(require tee tar python3 mktemp)
|
||||||
|
if [ -n "$NEEDS" ]; then
|
||||||
|
status "ERROR: The following tools are required but missing:"
|
||||||
|
for NEED in $NEEDS; do
|
||||||
|
echo " - $NEED"
|
||||||
|
done
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) error "Unsupported architecture: $ARCH" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
IS_WSL2=false
|
||||||
|
|
||||||
|
KERN=$(uname -r)
|
||||||
|
case "$KERN" in
|
||||||
|
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
|
||||||
|
*icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||||
|
*) ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
download_binary() {
|
||||||
|
ASSET_NAME="linux_$ARCH"
|
||||||
|
|
||||||
|
TMPDIR=$(mktemp -d)
|
||||||
|
trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP
|
||||||
|
PYTHON_SCRIPT=$(cat <<EOF
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
ASSET_NAME = "${ASSET_NAME}"
|
||||||
|
|
||||||
|
with urllib.request.urlopen("https://api.github.com/repos/mostlygeek/llama-swap/releases/latest") as resp:
|
||||||
|
data = json.load(resp)
|
||||||
|
for asset in data.get("assets", []):
|
||||||
|
if ASSET_NAME in asset.get("name", ""):
|
||||||
|
url = asset["browser_download_url"]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("ERROR: Matching asset not found.", file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("Downloading:", url, file=sys.stderr)
|
||||||
|
output_path = os.path.join("${TMPDIR}", "llama-swap.tar.gz")
|
||||||
|
urllib.request.urlretrieve(url, output_path)
|
||||||
|
print(output_path)
|
||||||
|
EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
TARFILE=$(python3 -c "$PYTHON_SCRIPT")
|
||||||
|
if [ ! -f "$TARFILE" ]; then
|
||||||
|
error "Failed to download binary."
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Extracting to /usr/local/bin"
|
||||||
|
$SUDO tar -xzf "$TARFILE" -C /usr/local/bin llama-swap
|
||||||
|
}
|
||||||
|
download_binary
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
if ! id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Creating llama-swap user..."
|
||||||
|
$SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap
|
||||||
|
fi
|
||||||
|
if getent group render >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to render group..."
|
||||||
|
$SUDO usermod -a -G render llama-swap
|
||||||
|
fi
|
||||||
|
if getent group video >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to video group..."
|
||||||
|
$SUDO usermod -a -G video llama-swap
|
||||||
|
fi
|
||||||
|
if getent group docker >/dev/null 2>&1; then
|
||||||
|
status "Adding llama-swap user to docker group..."
|
||||||
|
$SUDO usermod -a -G docker llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Adding current user to llama-swap group..."
|
||||||
|
$SUDO usermod -a -G llama-swap "$(whoami)"
|
||||||
|
|
||||||
|
if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
status "Creating default config.yaml..."
|
||||||
|
cat <<EOF | $SUDO -u llama-swap tee /usr/share/llama-swap/config.yaml >/dev/null
|
||||||
|
# default 15s likely to fail for default models due to downloading models
|
||||||
|
healthCheckTimeout: 60
|
||||||
|
|
||||||
|
models:
|
||||||
|
"qwen2.5":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name qwen2.5
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop qwen2.5
|
||||||
|
|
||||||
|
"smollm2":
|
||||||
|
cmd: |
|
||||||
|
docker run
|
||||||
|
--rm
|
||||||
|
-p \${PORT}:8080
|
||||||
|
--name smollm2
|
||||||
|
ghcr.io/ggml-org/llama.cpp:server
|
||||||
|
-hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M
|
||||||
|
cmdStop: docker stop smollm2
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
status "Creating llama-swap systemd service..."
|
||||||
|
cat <<EOF | $SUDO tee /etc/systemd/system/llama-swap.service >/dev/null
|
||||||
|
[Unit]
|
||||||
|
Description=llama-swap
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
User=llama-swap
|
||||||
|
Group=llama-swap
|
||||||
|
|
||||||
|
# set this to match your environment
|
||||||
|
ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS}
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=3
|
||||||
|
StartLimitBurst=3
|
||||||
|
StartLimitInterval=30
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
EOF
|
||||||
|
SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)"
|
||||||
|
case $SYSTEMCTL_RUNNING in
|
||||||
|
running|degraded)
|
||||||
|
status "Enabling and starting llama-swap service..."
|
||||||
|
$SUDO systemctl daemon-reload
|
||||||
|
$SUDO systemctl enable llama-swap
|
||||||
|
|
||||||
|
start_service() { $SUDO systemctl restart llama-swap; }
|
||||||
|
trap start_service EXIT
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
warning "systemd is not running"
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
warning "see https://learn.microsoft.com/en-us/windows/wsl/systemd#how-to-enable-systemd to enable it"
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success() {
|
||||||
|
status "The llama-swap API is now available at http://${LLAMA_SWAP_DEFAULT_ADDRESS}"
|
||||||
|
status 'Customize the config file at /usr/share/llama-swap/config.yaml.'
|
||||||
|
status 'Install complete.'
|
||||||
|
}
|
||||||
|
|
||||||
|
# WSL2 only supports GPUs via nvidia passthrough
|
||||||
|
# so check for nvidia-smi to determine if GPU is available
|
||||||
|
if [ "$IS_WSL2" = true ]; then
|
||||||
|
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||||
|
status "Nvidia GPU detected."
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_success
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# This script uninstalls llama-swap on Linux.
|
||||||
|
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||||
|
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"
|
||||||
|
|
||||||
|
status() { echo ">>> $*" >&2; }
|
||||||
|
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
|
||||||
|
warning() { echo "${red}WARNING:${plain} $*"; }
|
||||||
|
|
||||||
|
available() { command -v $1 >/dev/null; }
|
||||||
|
|
||||||
|
SUDO=
|
||||||
|
if [ "$(id -u)" -ne 0 ]; then
|
||||||
|
if ! available sudo; then
|
||||||
|
error "This script requires superuser permissions. Please re-run as root."
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUDO="sudo"
|
||||||
|
fi
|
||||||
|
|
||||||
|
configure_systemd() {
|
||||||
|
status "Stopping llama-swap service..."
|
||||||
|
$SUDO systemctl stop llama-swap
|
||||||
|
|
||||||
|
status "Disabling llama-swap service..."
|
||||||
|
$SUDO systemctl disable llama-swap
|
||||||
|
}
|
||||||
|
if available systemctl; then
|
||||||
|
configure_systemd
|
||||||
|
fi
|
||||||
|
|
||||||
|
if available llama-swap; then
|
||||||
|
status "Removing llama-swap binary..."
|
||||||
|
$SUDO rm $(which llama-swap)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -f "/usr/share/llama-swap/config.yaml" ]; then
|
||||||
|
while true; do
|
||||||
|
printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
|
||||||
|
read answer
|
||||||
|
case "$answer" in
|
||||||
|
[Yy]* )
|
||||||
|
$SUDO rm -r /usr/share/llama-swap
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
[Nn]* | "" )
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
* )
|
||||||
|
echo "Invalid input. Please enter y or n."
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if id llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap user..."
|
||||||
|
$SUDO userdel llama-swap
|
||||||
|
fi
|
||||||
|
|
||||||
|
if getent group llama-swap >/dev/null 2>&1; then
|
||||||
|
status "Removing llama-swap group..."
|
||||||
|
$SUDO groupdel llama-swap
|
||||||
|
fi
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
.vite
|
||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# React + TypeScript + Vite
|
||||||
|
|
||||||
|
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
||||||
|
|
||||||
|
Currently, two official plugins are available:
|
||||||
|
|
||||||
|
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
|
||||||
|
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
|
||||||
|
|
||||||
|
## Expanding the ESLint configuration
|
||||||
|
|
||||||
|
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
||||||
|
|
||||||
|
```js
|
||||||
|
export default tseslint.config({
|
||||||
|
extends: [
|
||||||
|
// Remove ...tseslint.configs.recommended and replace with this
|
||||||
|
...tseslint.configs.recommendedTypeChecked,
|
||||||
|
// Alternatively, use this for stricter rules
|
||||||
|
...tseslint.configs.strictTypeChecked,
|
||||||
|
// Optionally, add this for stylistic rules
|
||||||
|
...tseslint.configs.stylisticTypeChecked,
|
||||||
|
],
|
||||||
|
languageOptions: {
|
||||||
|
// other options...
|
||||||
|
parserOptions: {
|
||||||
|
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
||||||
|
tsconfigRootDir: import.meta.dirname,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
||||||
|
|
||||||
|
```js
|
||||||
|
// eslint.config.js
|
||||||
|
import reactX from 'eslint-plugin-react-x'
|
||||||
|
import reactDom from 'eslint-plugin-react-dom'
|
||||||
|
|
||||||
|
export default tseslint.config({
|
||||||
|
plugins: {
|
||||||
|
// Add the react-x and react-dom plugins
|
||||||
|
'react-x': reactX,
|
||||||
|
'react-dom': reactDom,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
// other rules...
|
||||||
|
// Enable its recommended typescript rules
|
||||||
|
...reactX.configs['recommended-typescript'].rules,
|
||||||
|
...reactDom.configs.recommended.rules,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import js from '@eslint/js'
|
||||||
|
import globals from 'globals'
|
||||||
|
import reactHooks from 'eslint-plugin-react-hooks'
|
||||||
|
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||||
|
import tseslint from 'typescript-eslint'
|
||||||
|
|
||||||
|
export default tseslint.config(
|
||||||
|
{ ignores: ['dist'] },
|
||||||
|
{
|
||||||
|
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
||||||
|
files: ['**/*.{ts,tsx}'],
|
||||||
|
languageOptions: {
|
||||||
|
ecmaVersion: 2020,
|
||||||
|
globals: globals.browser,
|
||||||
|
},
|
||||||
|
plugins: {
|
||||||
|
'react-hooks': reactHooks,
|
||||||
|
'react-refresh': reactRefresh,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
...reactHooks.configs.recommended.rules,
|
||||||
|
'react-refresh/only-export-components': [
|
||||||
|
'warn',
|
||||||
|
{ allowConstantExport: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||||
|
<link rel="shortcut icon" href="/favicon.ico" />
|
||||||
|
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||||
|
<link rel="manifest" href="/site.webmanifest" />
|
||||||
|
<title>llama-swap</title>
|
||||||
|
</head>
|
||||||
|
<body >
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"name": "ui",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"start": "vite",
|
||||||
|
"build": "tsc -b && vite build --emptyOutDir",
|
||||||
|
"lint": "eslint .",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^19.1.0",
|
||||||
|
"react-dom": "^19.1.0",
|
||||||
|
"react-icons": "^5.5.0",
|
||||||
|
"react-resizable-panels": "^3.0.4",
|
||||||
|
"react-router-dom": "^7.6.2"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@eslint/js": "^9.25.0",
|
||||||
|
"@tailwindcss/vite": "^4.1.8",
|
||||||
|
"@types/react": "^19.1.2",
|
||||||
|
"@types/react-dom": "^19.1.2",
|
||||||
|
"@vitejs/plugin-react": "^4.4.1",
|
||||||
|
"eslint": "^9.25.0",
|
||||||
|
"eslint-plugin-react-hooks": "^5.2.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.4.19",
|
||||||
|
"globals": "^16.0.0",
|
||||||
|
"tailwindcss": "^4.1.8",
|
||||||
|
"typescript": "~5.8.3",
|
||||||
|
"typescript-eslint": "^8.30.1",
|
||||||
|
"vite": "^6.3.5"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
After Width: | Height: | Size: 5.9 KiB |
|
After Width: | Height: | Size: 2.2 KiB |
|
After Width: | Height: | Size: 15 KiB |
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"name": "llama-swap",
|
||||||
|
"short_name": "llama-swap",
|
||||||
|
"icons": [
|
||||||
|
{
|
||||||
|
"src": "/web-app-manifest-192x192.png",
|
||||||
|
"sizes": "192x192",
|
||||||
|
"type": "image/png",
|
||||||
|
"purpose": "maskable"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"src": "/web-app-manifest-512x512.png",
|
||||||
|
"sizes": "512x512",
|
||||||
|
"type": "image/png",
|
||||||
|
"purpose": "maskable"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"theme_color": "#ffffff",
|
||||||
|
"background_color": "#ffffff",
|
||||||
|
"display": "standalone"
|
||||||
|
}
|
||||||
|
After Width: | Height: | Size: 6.5 KiB |
|
After Width: | Height: | Size: 28 KiB |
@@ -0,0 +1,6 @@
|
|||||||
|
#root {
|
||||||
|
max-width: 1280px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 2rem;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
import { useEffect } from "react";
|
||||||
|
import { Navigate, Route, BrowserRouter as Router, Routes } from "react-router-dom";
|
||||||
|
import { Header } from "./components/Header";
|
||||||
|
import { useAPI } from "./contexts/APIProvider";
|
||||||
|
import { useTheme } from "./contexts/ThemeProvider";
|
||||||
|
import ActivityPage from "./pages/Activity";
|
||||||
|
import LogViewerPage from "./pages/LogViewer";
|
||||||
|
import ModelPage from "./pages/Models";
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const { setConnectionState } = useTheme();
|
||||||
|
|
||||||
|
const { connectionStatus } = useAPI();
|
||||||
|
|
||||||
|
// Synchronize the window.title connections state with the actual connection state
|
||||||
|
useEffect(() => {
|
||||||
|
setConnectionState(connectionStatus);
|
||||||
|
}, [connectionStatus]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Router basename="/ui/">
|
||||||
|
<div className="flex flex-col h-screen">
|
||||||
|
<Header />
|
||||||
|
|
||||||
|
<main className="flex-1 overflow-auto p-4">
|
||||||
|
<Routes>
|
||||||
|
<Route path="/" element={<LogViewerPage />} />
|
||||||
|
<Route path="/models" element={<ModelPage />} />
|
||||||
|
<Route path="/activity" element={<ActivityPage />} />
|
||||||
|
<Route path="*" element={<Navigate to="/" replace />} />
|
||||||
|
</Routes>
|
||||||
|
</main>
|
||||||
|
</div>
|
||||||
|
</Router>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
|
After Width: | Height: | Size: 12 KiB |
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 4.0 KiB |
@@ -0,0 +1,26 @@
|
|||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
|
||||||
|
const ConnectionStatusIcon = () => {
|
||||||
|
const { connectionStatus } = useAPI();
|
||||||
|
|
||||||
|
const eventStatusColor = useMemo(() => {
|
||||||
|
switch (connectionStatus) {
|
||||||
|
case "connected":
|
||||||
|
return "bg-emerald-500";
|
||||||
|
case "connecting":
|
||||||
|
return "bg-amber-500";
|
||||||
|
case "disconnected":
|
||||||
|
default:
|
||||||
|
return "bg-red-500";
|
||||||
|
}
|
||||||
|
}, [connectionStatus]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center" title={`event stream: ${connectionStatus}`}>
|
||||||
|
<span className={`inline-block w-3 h-3 rounded-full ${eventStatusColor} mr-2`}></span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ConnectionStatusIcon;
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
import { useCallback } from "react";
|
||||||
|
import { RiMoonFill, RiSunFill } from "react-icons/ri";
|
||||||
|
import { NavLink, type NavLinkRenderProps } from "react-router-dom";
|
||||||
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
|
import ConnectionStatusIcon from "./ConnectionStatus";
|
||||||
|
|
||||||
|
export function Header() {
|
||||||
|
const { screenWidth, toggleTheme, isDarkMode, appTitle, setAppTitle, isNarrow } = useTheme();
|
||||||
|
const handleTitleChange = useCallback(
|
||||||
|
(newTitle: string) => {
|
||||||
|
setAppTitle(newTitle.replace(/\n/g, "").trim().substring(0, 64) || "llama-swap");
|
||||||
|
},
|
||||||
|
[setAppTitle]
|
||||||
|
);
|
||||||
|
|
||||||
|
const navLinkClass = ({ isActive }: NavLinkRenderProps) =>
|
||||||
|
`text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 ${isActive ? "font-semibold" : ""}`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<header className={`flex items-center justify-between bg-surface border-b border-border px-4 ${isNarrow ? "py-1 h-[60px]" : "p-2 h-[75px]"}`}>
|
||||||
|
{screenWidth !== "xs" && screenWidth !== "sm" && (
|
||||||
|
<h1
|
||||||
|
contentEditable
|
||||||
|
suppressContentEditableWarning
|
||||||
|
className="p-0 outline-none hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||||
|
onBlur={(e) => handleTitleChange(e.currentTarget.textContent || "(set title)")}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter") {
|
||||||
|
e.preventDefault();
|
||||||
|
handleTitleChange(e.currentTarget.textContent || "(set title)");
|
||||||
|
e.currentTarget.blur();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{appTitle}
|
||||||
|
</h1>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<menu className="flex items-center gap-4">
|
||||||
|
<NavLink to="/" className={navLinkClass} type="button">
|
||||||
|
Logs
|
||||||
|
</NavLink>
|
||||||
|
<NavLink to="/models" className={navLinkClass} type="button">
|
||||||
|
Models
|
||||||
|
</NavLink>
|
||||||
|
<NavLink to="/activity" className={navLinkClass} type="button">
|
||||||
|
Activity
|
||||||
|
</NavLink>
|
||||||
|
<button className="" onClick={toggleTheme}>
|
||||||
|
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
|
||||||
|
</button>
|
||||||
|
<ConnectionStatusIcon />
|
||||||
|
</menu>
|
||||||
|
</header>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,246 @@
|
|||||||
|
import { createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
|
||||||
|
import type { ConnectionState } from "../lib/types";
|
||||||
|
|
||||||
|
type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
|
||||||
|
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
|
||||||
|
|
||||||
|
export interface Model {
|
||||||
|
id: string;
|
||||||
|
state: ModelStatus;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
unlisted: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface APIProviderType {
|
||||||
|
models: Model[];
|
||||||
|
listModels: () => Promise<Model[]>;
|
||||||
|
unloadAllModels: () => Promise<void>;
|
||||||
|
unloadSingleModel: (model: string) => Promise<void>;
|
||||||
|
loadModel: (model: string) => Promise<void>;
|
||||||
|
enableAPIEvents: (enabled: boolean) => void;
|
||||||
|
proxyLogs: string;
|
||||||
|
upstreamLogs: string;
|
||||||
|
metrics: Metrics[];
|
||||||
|
connectionStatus: ConnectionState;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Metrics {
|
||||||
|
id: number;
|
||||||
|
timestamp: string;
|
||||||
|
model: string;
|
||||||
|
cache_tokens: number;
|
||||||
|
input_tokens: number;
|
||||||
|
output_tokens: number;
|
||||||
|
prompt_per_second: number;
|
||||||
|
tokens_per_second: number;
|
||||||
|
duration_ms: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface LogData {
|
||||||
|
source: "upstream" | "proxy";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
interface APIEventEnvelope {
|
||||||
|
type: "modelStatus" | "logData" | "metrics";
|
||||||
|
data: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const APIContext = createContext<APIProviderType | undefined>(undefined);
|
||||||
|
type APIProviderProps = {
|
||||||
|
children: ReactNode;
|
||||||
|
autoStartAPIEvents?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
let apiEventSource: EventSource | null = null;
|
||||||
|
|
||||||
|
export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
|
||||||
|
const [proxyLogs, setProxyLogs] = useState("");
|
||||||
|
const [upstreamLogs, setUpstreamLogs] = useState("");
|
||||||
|
const [metrics, setMetrics] = useState<Metrics[]>([]);
|
||||||
|
const [connectionStatus, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||||
|
//const apiEventSource = useRef<EventSource | null>(null);
|
||||||
|
|
||||||
|
const [models, setModels] = useState<Model[]>([]);
|
||||||
|
|
||||||
|
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
|
||||||
|
setter((prev) => {
|
||||||
|
const updatedLog = prev + newData;
|
||||||
|
return updatedLog.length > LOG_LENGTH_LIMIT ? updatedLog.slice(-LOG_LENGTH_LIMIT) : updatedLog;
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const enableAPIEvents = useCallback((enabled: boolean) => {
|
||||||
|
if (!enabled) {
|
||||||
|
apiEventSource?.close();
|
||||||
|
apiEventSource = null;
|
||||||
|
setMetrics([]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let retryCount = 0;
|
||||||
|
const initialDelay = 1000; // 1 second
|
||||||
|
|
||||||
|
const connect = () => {
|
||||||
|
apiEventSource?.close();
|
||||||
|
apiEventSource = new EventSource("/api/events");
|
||||||
|
|
||||||
|
setConnectionState("connecting");
|
||||||
|
|
||||||
|
apiEventSource.onopen = () => {
|
||||||
|
// clear everything out on connect to keep things in sync
|
||||||
|
setProxyLogs("");
|
||||||
|
setUpstreamLogs("");
|
||||||
|
setMetrics([]); // clear metrics on reconnect
|
||||||
|
setModels([]); // clear models on reconnect
|
||||||
|
retryCount = 0;
|
||||||
|
setConnectionState("connected");
|
||||||
|
};
|
||||||
|
|
||||||
|
apiEventSource.onmessage = (e: MessageEvent) => {
|
||||||
|
try {
|
||||||
|
const message = JSON.parse(e.data) as APIEventEnvelope;
|
||||||
|
switch (message.type) {
|
||||||
|
case "modelStatus":
|
||||||
|
{
|
||||||
|
const models = JSON.parse(message.data) as Model[];
|
||||||
|
|
||||||
|
// sort models by name and id
|
||||||
|
models.sort((a, b) => {
|
||||||
|
return (a.name + a.id).localeCompare(b.name + b.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
setModels(models);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "logData":
|
||||||
|
const logData = JSON.parse(message.data) as LogData;
|
||||||
|
switch (logData.source) {
|
||||||
|
case "proxy":
|
||||||
|
appendLog(logData.data, setProxyLogs);
|
||||||
|
break;
|
||||||
|
case "upstream":
|
||||||
|
appendLog(logData.data, setUpstreamLogs);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "metrics":
|
||||||
|
{
|
||||||
|
const newMetrics = JSON.parse(message.data) as Metrics[];
|
||||||
|
setMetrics((prevMetrics) => {
|
||||||
|
return [...newMetrics, ...prevMetrics];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error(e.data, err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
apiEventSource.onerror = () => {
|
||||||
|
apiEventSource?.close();
|
||||||
|
retryCount++;
|
||||||
|
const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
|
||||||
|
setConnectionState("disconnected");
|
||||||
|
setTimeout(connect, delay);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
connect();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (autoStartAPIEvents) {
|
||||||
|
enableAPIEvents(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
enableAPIEvents(false);
|
||||||
|
};
|
||||||
|
}, [enableAPIEvents, autoStartAPIEvents]);
|
||||||
|
|
||||||
|
const listModels = useCallback(async (): Promise<Model[]> => {
|
||||||
|
try {
|
||||||
|
const response = await fetch("/api/models/");
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
|
}
|
||||||
|
const data = await response.json();
|
||||||
|
return data || [];
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to fetch models:", error);
|
||||||
|
return []; // Return empty array as fallback
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const unloadAllModels = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/models/unload`, {
|
||||||
|
method: "POST",
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to unload models: ${response.status}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to unload models:", error);
|
||||||
|
throw error; // Re-throw to let calling code handle it
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const unloadSingleModel = useCallback(async (model: string) => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/models/unload/${model}`, {
|
||||||
|
method: "POST",
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to unload model: ${response.status}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to unload model", model, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const loadModel = useCallback(async (model: string) => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/upstream/${model}/`, {
|
||||||
|
method: "GET",
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to load model: ${response.status}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to load model:", error);
|
||||||
|
throw error; // Re-throw to let calling code handle it
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const value = useMemo(
|
||||||
|
() => ({
|
||||||
|
models,
|
||||||
|
listModels,
|
||||||
|
unloadAllModels,
|
||||||
|
unloadSingleModel,
|
||||||
|
loadModel,
|
||||||
|
enableAPIEvents,
|
||||||
|
proxyLogs,
|
||||||
|
upstreamLogs,
|
||||||
|
metrics,
|
||||||
|
connectionStatus,
|
||||||
|
}),
|
||||||
|
[models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
|
||||||
|
);
|
||||||
|
|
||||||
|
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useAPI() {
|
||||||
|
const context = useContext(APIContext);
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useAPI must be used within an APIProvider");
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import type { ConnectionState } from "../lib/types";
|
||||||
|
|
||||||
|
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
|
||||||
|
type ThemeContextType = {
|
||||||
|
isDarkMode: boolean;
|
||||||
|
screenWidth: ScreenWidth;
|
||||||
|
isNarrow: boolean;
|
||||||
|
toggleTheme: () => void;
|
||||||
|
|
||||||
|
// for managing the window title and connection state information
|
||||||
|
appTitle: string;
|
||||||
|
setAppTitle: (title: string) => void;
|
||||||
|
setConnectionState: (state: ConnectionState) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ThemeContext = createContext<ThemeContextType | undefined>(undefined);
|
||||||
|
|
||||||
|
type ThemeProviderProps = {
|
||||||
|
children: ReactNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ThemeProvider({ children }: ThemeProviderProps) {
|
||||||
|
const [appTitle, setAppTitle] = usePersistentState("app-title", "llama-swap");
|
||||||
|
const [connectionState, setConnectionState] = useState<ConnectionState>("disconnected");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the document.title with informative information
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
const connectionIcon = connectionState === "connecting" ? "🟡" : connectionState === "connected" ? "🟢" : "🔴";
|
||||||
|
document.title = connectionIcon + " " + appTitle; // Set initial title
|
||||||
|
}, [appTitle, connectionState]);
|
||||||
|
|
||||||
|
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
|
||||||
|
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
|
||||||
|
|
||||||
|
// matches tailwind classes
|
||||||
|
// https://tailwindcss.com/docs/responsive-design
|
||||||
|
useEffect(() => {
|
||||||
|
const checkInnerWidth = () => {
|
||||||
|
const innerWidth = window.innerWidth;
|
||||||
|
if (innerWidth < 640) {
|
||||||
|
setScreenWidth("xs");
|
||||||
|
} else if (innerWidth < 768) {
|
||||||
|
setScreenWidth("sm");
|
||||||
|
} else if (innerWidth < 1024) {
|
||||||
|
setScreenWidth("md");
|
||||||
|
} else if (innerWidth < 1280) {
|
||||||
|
setScreenWidth("lg");
|
||||||
|
} else if (innerWidth < 1536) {
|
||||||
|
setScreenWidth("xl");
|
||||||
|
} else {
|
||||||
|
setScreenWidth("2xl");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
checkInnerWidth();
|
||||||
|
window.addEventListener("resize", checkInnerWidth);
|
||||||
|
|
||||||
|
return () => window.removeEventListener("resize", checkInnerWidth);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
|
||||||
|
}, [isDarkMode]);
|
||||||
|
|
||||||
|
const toggleTheme = () => setIsDarkMode((prev) => !prev);
|
||||||
|
const isNarrow = useMemo(() => {
|
||||||
|
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
|
||||||
|
}, [screenWidth]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ThemeContext.Provider
|
||||||
|
value={{
|
||||||
|
isDarkMode,
|
||||||
|
toggleTheme,
|
||||||
|
screenWidth,
|
||||||
|
isNarrow,
|
||||||
|
appTitle,
|
||||||
|
setAppTitle,
|
||||||
|
setConnectionState,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</ThemeContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useTheme(): ThemeContextType {
|
||||||
|
const context = useContext(ThemeContext);
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useTheme must be used within a ThemeProvider");
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import { useState, useEffect, useCallback } from "react";
|
||||||
|
|
||||||
|
export function usePersistentState<T>(key: string, initialValue: T): [T, (value: T | ((prevState: T) => T)) => void] {
|
||||||
|
const [state, setState] = useState<T>(() => {
|
||||||
|
if (typeof window === "undefined") return initialValue;
|
||||||
|
try {
|
||||||
|
const saved = localStorage.getItem(key);
|
||||||
|
return saved !== null ? JSON.parse(saved) : initialValue;
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error parsing stored value for ${key}`, e);
|
||||||
|
return initialValue;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const setPersistentState = useCallback(
|
||||||
|
(value: T | ((prevState: T) => T)) => {
|
||||||
|
setState((prev) => {
|
||||||
|
const nextValue = typeof value === "function" ? (value as (prevState: T) => T)(prev) : value;
|
||||||
|
try {
|
||||||
|
localStorage.setItem(key, JSON.stringify(nextValue));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error saving value for ${key}`, e);
|
||||||
|
}
|
||||||
|
return nextValue;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[key]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(key, JSON.stringify(state));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Error saving value for ${key}`, e);
|
||||||
|
}
|
||||||
|
}, [key, state]);
|
||||||
|
|
||||||
|
return [state, setPersistentState];
|
||||||
|
}
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
@import "tailwindcss";
|
||||||
|
@custom-variant dark (&:where([data-theme=dark], [data-theme=dark] *));
|
||||||
|
|
||||||
|
@theme {
|
||||||
|
--color-background: rgba(252, 252, 249, 1);
|
||||||
|
--color-surface: rgba(255, 255, 253, 1);
|
||||||
|
|
||||||
|
/* text colors */
|
||||||
|
--color-txtmain: rgba(19, 52, 59, 1);
|
||||||
|
--color-txtsecondary: rgba(98, 108, 113, 1);
|
||||||
|
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||||
|
|
||||||
|
--color-primary: rgba(50, 184, 198, 1);
|
||||||
|
|
||||||
|
--color-primary-hover: rgba(29, 116, 128, 1);
|
||||||
|
--color-primary-active: rgba(26, 104, 115, 1);
|
||||||
|
--color-secondary: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-secondary-hover: rgba(94, 82, 64, 0.2);
|
||||||
|
--color-secondary-active: rgba(94, 82, 64, 0.25);
|
||||||
|
--color-border: rgba(94, 82, 64, 0.3);
|
||||||
|
--color-btn-primary-text: rgba(252, 252, 249, 1);
|
||||||
|
--color-card-border: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-card-border-inner: rgba(94, 82, 64, 0.12);
|
||||||
|
--color-error: rgba(192, 21, 47, 1);
|
||||||
|
--color-success: rgba(33, 128, 141, 1);
|
||||||
|
--color-warning: rgb(244, 155, 0);
|
||||||
|
--color-info: rgba(98, 108, 113, 1);
|
||||||
|
--color-focus-ring: rgba(33, 128, 141, 0.4);
|
||||||
|
--color-select-caret: rgba(19, 52, 59, 0.8);
|
||||||
|
--color-btn-border: rgba(94, 82, 64, 0.7);
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer theme {
|
||||||
|
/* over ride theme for dark mode */
|
||||||
|
[data-theme="dark"] {
|
||||||
|
--color-background: rgba(31, 33, 33, 1);
|
||||||
|
--color-surface: rgba(38, 40, 40, 1);
|
||||||
|
/* text colors */
|
||||||
|
--color-txtmain: rgba(245, 245, 245, 1);
|
||||||
|
--color-txtsecondary: rgba(167, 169, 169, 0.7);
|
||||||
|
|
||||||
|
--color-navlink-active: rgba(245, 245, 245, 1);
|
||||||
|
|
||||||
|
--color-primary: rgba(33, 128, 141, 1);
|
||||||
|
--color-primary-hover: rgba(45, 166, 178, 1);
|
||||||
|
--color-primary-active: rgba(41, 150, 161, 1);
|
||||||
|
--color-secondary: rgba(119, 124, 124, 0.15);
|
||||||
|
--color-secondary-hover: rgba(119, 124, 124, 0.25);
|
||||||
|
--color-secondary-active: rgba(119, 124, 124, 0.3);
|
||||||
|
--color-border: rgba(119, 124, 124, 0.3);
|
||||||
|
--color-error: rgba(255, 84, 89, 1);
|
||||||
|
--color-success: rgba(50, 184, 198, 1);
|
||||||
|
--color-warning: rgb(244, 155, 0);
|
||||||
|
--color-info: rgba(167, 169, 169, 1);
|
||||||
|
--color-focus-ring: rgba(50, 184, 198, 0.4);
|
||||||
|
--color-btn-primary-text: rgba(19, 52, 59, 1);
|
||||||
|
--color-card-border: rgba(119, 124, 124, 0.2);
|
||||||
|
--color-card-border-inner: rgba(119, 124, 124, 0.15);
|
||||||
|
--shadow-inset-sm: inset 0 1px 0 rgba(255, 255, 255, 0.1), inset 0 -1px 0 rgba(0, 0, 0, 0.15);
|
||||||
|
--button-border-secondary: rgba(119, 124, 124, 0.2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer base {
|
||||||
|
body {
|
||||||
|
/* example of how colors using theme colors*/
|
||||||
|
@apply bg-background text-txtmain;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
@apply text-4xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
@apply text-3xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h3 {
|
||||||
|
@apply text-2xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h4 {
|
||||||
|
@apply text-xl text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h5 {
|
||||||
|
@apply text-lg text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
h6 {
|
||||||
|
@apply text-base text-txtmain font-bold pb-4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* define CSS classes here for specific types of components */
|
||||||
|
@layer components {
|
||||||
|
.container {
|
||||||
|
@apply px-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Tables */
|
||||||
|
table th {
|
||||||
|
@apply p-2 font-semibold;
|
||||||
|
}
|
||||||
|
table td {
|
||||||
|
@apply p-2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Navigation Header */
|
||||||
|
|
||||||
|
.navlink {
|
||||||
|
@apply text-txtsecondary hover:bg-secondary hover:text-txtmain rounded-lg p-2;
|
||||||
|
}
|
||||||
|
.navlink.active {
|
||||||
|
@apply bg-primary text-navlink-active;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Card component */
|
||||||
|
.card {
|
||||||
|
@apply bg-surface rounded-lg border border-card-border shadow-sm overflow-hidden p-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card:hover {
|
||||||
|
@apply shadow-md;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card__body {
|
||||||
|
@apply p-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card__header,
|
||||||
|
.card__footer {
|
||||||
|
@apply p-4 border-b border-card-border-inner;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Status Badges */
|
||||||
|
.status {
|
||||||
|
@apply inline-block px-2 py-1 text-xs font-medium rounded-lg;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--ready {
|
||||||
|
@apply bg-success/10 text-success;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--starting,
|
||||||
|
.status--stopping {
|
||||||
|
@apply bg-warning/10 text-warning;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status--stopped {
|
||||||
|
@apply bg-error/10 text-error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Buttons */
|
||||||
|
.btn {
|
||||||
|
@apply bg-surface py-2 px-4 text-sm rounded-md border transition-colors duration-200 border-btn-border;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn:hover {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn--sm {
|
||||||
|
@apply px-2 py-0.5 text-xs;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn:disabled {
|
||||||
|
@apply opacity-50 cursor-not-allowed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@layer utilities {
|
||||||
|
.ml-2 {
|
||||||
|
margin-left: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.my-8 {
|
||||||
|
margin-top: 2rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
export type ConnectionState = "connected" | "connecting" | "disconnected";
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { StrictMode } from "react";
|
||||||
|
import { createRoot } from "react-dom/client";
|
||||||
|
import "./index.css";
|
||||||
|
import App from "./App.tsx";
|
||||||
|
import { ThemeProvider } from "./contexts/ThemeProvider";
|
||||||
|
import { APIProvider } from "./contexts/APIProvider";
|
||||||
|
|
||||||
|
createRoot(document.getElementById("root")!).render(
|
||||||
|
<StrictMode>
|
||||||
|
<ThemeProvider>
|
||||||
|
<APIProvider>
|
||||||
|
<App />
|
||||||
|
</APIProvider>
|
||||||
|
</ThemeProvider>
|
||||||
|
</StrictMode>
|
||||||
|
);
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
import { useMemo } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
|
||||||
|
const formatSpeed = (speed: number): string => {
|
||||||
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatDuration = (ms: number): string => {
|
||||||
|
return (ms / 1000).toFixed(2) + "s";
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatRelativeTime = (timestamp: string): string => {
|
||||||
|
const now = new Date();
|
||||||
|
const date = new Date(timestamp);
|
||||||
|
const diffInSeconds = Math.floor((now.getTime() - date.getTime()) / 1000);
|
||||||
|
|
||||||
|
// Handle future dates by returning "just now"
|
||||||
|
if (diffInSeconds < 5) {
|
||||||
|
return "now";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diffInSeconds < 60) {
|
||||||
|
return `${diffInSeconds}s ago`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const diffInMinutes = Math.floor(diffInSeconds / 60);
|
||||||
|
if (diffInMinutes < 60) {
|
||||||
|
return `${diffInMinutes}m ago`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const diffInHours = Math.floor(diffInMinutes / 60);
|
||||||
|
if (diffInHours < 24) {
|
||||||
|
return `${diffInHours}h ago`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return "a while ago";
|
||||||
|
};
|
||||||
|
|
||||||
|
const ActivityPage = () => {
|
||||||
|
const { metrics } = useAPI();
|
||||||
|
const sortedMetrics = useMemo(() => {
|
||||||
|
return [...metrics].sort((a, b) => b.id - a.id);
|
||||||
|
}, [metrics]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="p-2">
|
||||||
|
<h1 className="text-2xl font-bold">Activity</h1>
|
||||||
|
|
||||||
|
{metrics.length === 0 && (
|
||||||
|
<div className="text-center py-8">
|
||||||
|
<p className="text-gray-600">No metrics data available</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{metrics.length > 0 && (
|
||||||
|
<div className="card overflow-auto">
|
||||||
|
<table className="min-w-full divide-y">
|
||||||
|
<thead className="border-gray-200 dark:border-white/10">
|
||||||
|
<tr className="text-left text-xs uppercase tracking-wider">
|
||||||
|
<th className="px-6 py-3">ID</th>
|
||||||
|
<th className="px-6 py-3">Time</th>
|
||||||
|
<th className="px-6 py-3">Model</th>
|
||||||
|
<th className="px-6 py-3">
|
||||||
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
|
</th>
|
||||||
|
<th className="px-6 py-3">
|
||||||
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
|
</th>
|
||||||
|
<th className="px-6 py-3">Generated</th>
|
||||||
|
<th className="px-6 py-3">Prompt Processing</th>
|
||||||
|
<th className="px-6 py-3">Generation Speed</th>
|
||||||
|
<th className="px-6 py-3">Duration</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y">
|
||||||
|
{sortedMetrics.map((metric) => (
|
||||||
|
<tr key={metric.id} className="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||||
|
<td className="px-4 py-4">{metric.id + 1 /* un-zero index */}</td>
|
||||||
|
<td className="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
|
||||||
|
<td className="px-6 py-4">{metric.model}</td>
|
||||||
|
<td className="px-6 py-4">{metric.cache_tokens > 0 ? metric.cache_tokens.toLocaleString() : "-"}</td>
|
||||||
|
<td className="px-6 py-4">{metric.input_tokens.toLocaleString()}</td>
|
||||||
|
<td className="px-6 py-4">{metric.output_tokens.toLocaleString()}</td>
|
||||||
|
<td className="px-6 py-4">{formatSpeed(metric.prompt_per_second)}</td>
|
||||||
|
<td className="px-6 py-4">{formatSpeed(metric.tokens_per_second)}</td>
|
||||||
|
<td className="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface TooltipProps {
|
||||||
|
content: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Tooltip: React.FC<TooltipProps> = ({ content }) => {
|
||||||
|
return (
|
||||||
|
<div className="relative group inline-block">
|
||||||
|
ⓘ
|
||||||
|
<div
|
||||||
|
className="absolute top-full left-1/2 transform -translate-x-1/2 mt-2
|
||||||
|
px-3 py-2 bg-gray-900 text-white text-sm rounded-md
|
||||||
|
opacity-0 group-hover:opacity-100 transition-opacity
|
||||||
|
duration-200 pointer-events-none whitespace-nowrap z-50 normal-case"
|
||||||
|
>
|
||||||
|
{content}
|
||||||
|
<div
|
||||||
|
className="absolute bottom-full left-1/2 transform -translate-x-1/2
|
||||||
|
border-4 border-transparent border-b-gray-900"
|
||||||
|
></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ActivityPage;
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
|
import {
|
||||||
|
RiTextWrap,
|
||||||
|
RiAlignJustify,
|
||||||
|
RiFontSize,
|
||||||
|
RiMenuSearchLine,
|
||||||
|
RiMenuSearchFill,
|
||||||
|
RiCloseCircleFill,
|
||||||
|
} from "react-icons/ri";
|
||||||
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
|
|
||||||
|
const LogViewer = () => {
|
||||||
|
const { proxyLogs, upstreamLogs } = useAPI();
|
||||||
|
const { screenWidth } = useTheme();
|
||||||
|
const direction = screenWidth === "xs" || screenWidth === "sm" ? "vertical" : "horizontal";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
|
||||||
|
<Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||||
|
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
|
||||||
|
</Panel>
|
||||||
|
<PanelResizeHandle
|
||||||
|
className={
|
||||||
|
direction === "horizontal"
|
||||||
|
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
|
||||||
|
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface LogPanelProps {
|
||||||
|
id: string;
|
||||||
|
title: string;
|
||||||
|
logData: string;
|
||||||
|
}
|
||||||
|
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
|
||||||
|
const [filterRegex, setFilterRegex] = useState("");
|
||||||
|
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
|
||||||
|
`logPanel-${id}-fontSize`,
|
||||||
|
"normal"
|
||||||
|
);
|
||||||
|
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
|
||||||
|
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
|
||||||
|
|
||||||
|
const textWrapClass = useMemo(() => {
|
||||||
|
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
|
||||||
|
}, [wrapText]);
|
||||||
|
|
||||||
|
const toggleFontSize = useCallback(() => {
|
||||||
|
setFontSize((prev) => {
|
||||||
|
switch (prev) {
|
||||||
|
case "xxs":
|
||||||
|
return "xs";
|
||||||
|
case "xs":
|
||||||
|
return "small";
|
||||||
|
case "small":
|
||||||
|
return "normal";
|
||||||
|
case "normal":
|
||||||
|
return "xxs";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const toggleWrapText = useCallback(() => {
|
||||||
|
setTextWrap((prev) => !prev);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const toggleFilter = useCallback(() => {
|
||||||
|
if (showFilter) {
|
||||||
|
setShowFilter(false);
|
||||||
|
setFilterRegex(""); // Clear filter when closing
|
||||||
|
} else {
|
||||||
|
setShowFilter(true);
|
||||||
|
}
|
||||||
|
}, [filterRegex, setFilterRegex, showFilter]);
|
||||||
|
|
||||||
|
const fontSizeClass = useMemo(() => {
|
||||||
|
switch (fontSize) {
|
||||||
|
case "xxs":
|
||||||
|
return "text-[0.5rem]"; // 0.5rem (8px)
|
||||||
|
case "xs":
|
||||||
|
return "text-[0.75rem]"; // 0.75rem (12px)
|
||||||
|
case "small":
|
||||||
|
return "text-[0.875rem]"; // 0.875rem (14px)
|
||||||
|
case "normal":
|
||||||
|
return "text-base"; // 1rem (16px)
|
||||||
|
}
|
||||||
|
}, [fontSize]);
|
||||||
|
|
||||||
|
const filteredLogs = useMemo(() => {
|
||||||
|
if (!filterRegex) return logData;
|
||||||
|
try {
|
||||||
|
const regex = new RegExp(filterRegex, "i");
|
||||||
|
const lines = logData.split("\n");
|
||||||
|
const filtered = lines.filter((line) => regex.test(line));
|
||||||
|
return filtered.join("\n");
|
||||||
|
} catch (e) {
|
||||||
|
return logData; // Return unfiltered if regex is invalid
|
||||||
|
}
|
||||||
|
}, [logData, filterRegex]);
|
||||||
|
|
||||||
|
// auto scroll to bottom
|
||||||
|
const preTagRef = useRef<HTMLPreElement>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
if (!preTagRef.current) return;
|
||||||
|
preTagRef.current.scrollTop = preTagRef.current.scrollHeight;
|
||||||
|
}, [filteredLogs]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-lg overflow-hidden flex flex-col bg-gray-950/5 dark:bg-white/10 h-full p-1">
|
||||||
|
<div className="p-4">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h3 className="m-0 text-lg p-0">{title}</h3>
|
||||||
|
|
||||||
|
<div className="flex gap-2 items-center">
|
||||||
|
<button className="btn border-0" onClick={toggleFontSize}>
|
||||||
|
<RiFontSize />
|
||||||
|
</button>
|
||||||
|
<button className="btn border-0" onClick={toggleWrapText}>
|
||||||
|
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
|
||||||
|
</button>
|
||||||
|
<button className="btn border-0" onClick={toggleFilter}>
|
||||||
|
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Filtering Options - Full width on mobile, normal on desktop */}
|
||||||
|
{showFilter && (
|
||||||
|
<div className="mt-2 w-full">
|
||||||
|
<div className="flex gap-2 items-center w-full">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
className="w-full text-sm border border-gray-950/10 dark:border-white/5 p-2 rounded outline-none"
|
||||||
|
placeholder="Filter logs..."
|
||||||
|
value={filterRegex}
|
||||||
|
onChange={(e) => setFilterRegex(e.target.value)}
|
||||||
|
/>
|
||||||
|
<button className="pl-2" onClick={() => setFilterRegex("")}>
|
||||||
|
<RiCloseCircleFill size="24" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="rounded-lg bg-background font-mono text-sm flex-1 overflow-hidden">
|
||||||
|
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
|
||||||
|
{filteredLogs}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
export default LogViewer;
|
||||||
@@ -0,0 +1,236 @@
|
|||||||
|
import { useState, useCallback, useMemo } from "react";
|
||||||
|
import { useAPI } from "../contexts/APIProvider";
|
||||||
|
import { LogPanel } from "./LogViewer";
|
||||||
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
|
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine, RiMenuFill } from "react-icons/ri";
|
||||||
|
|
||||||
|
export default function ModelsPage() {
|
||||||
|
const { isNarrow } = useTheme();
|
||||||
|
const direction = isNarrow ? "vertical" : "horizontal";
|
||||||
|
const { upstreamLogs } = useAPI();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
|
||||||
|
<Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
|
||||||
|
<ModelsPanel />
|
||||||
|
</Panel>
|
||||||
|
|
||||||
|
<PanelResizeHandle
|
||||||
|
className={
|
||||||
|
direction === "horizontal"
|
||||||
|
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Panel collapsible={true} defaultSize={50} minSize={0}>
|
||||||
|
<div className="flex flex-col h-full space-y-4">
|
||||||
|
{direction === "horizontal" && <StatsPanel />}
|
||||||
|
<div className="flex-1 min-h-0">
|
||||||
|
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function ModelsPanel() {
|
||||||
|
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
||||||
|
const { isNarrow } = useTheme();
|
||||||
|
const [isUnloading, setIsUnloading] = useState(false);
|
||||||
|
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||||
|
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||||
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
|
||||||
|
const filteredModels = useMemo(() => {
|
||||||
|
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||||
|
}, [models, showUnlisted]);
|
||||||
|
|
||||||
|
const handleUnloadAllModels = useCallback(async () => {
|
||||||
|
setIsUnloading(true);
|
||||||
|
try {
|
||||||
|
await unloadAllModels();
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
} finally {
|
||||||
|
setTimeout(() => {
|
||||||
|
setIsUnloading(false);
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
}, [unloadAllModels]);
|
||||||
|
|
||||||
|
const toggleIdorName = useCallback(() => {
|
||||||
|
setShowIdorName((prev) => (prev === "name" ? "id" : "name"));
|
||||||
|
}, [showIdorName]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="card h-full flex flex-col">
|
||||||
|
<div className="shrink-0">
|
||||||
|
<div className="flex justify-between items-baseline">
|
||||||
|
<h2 className={isNarrow ? "text-xl" : ""}>Models</h2>
|
||||||
|
{isNarrow && (
|
||||||
|
<div className="relative">
|
||||||
|
<button className="btn text-base flex items-center gap-2 py-1" onClick={() => setMenuOpen(!menuOpen)}>
|
||||||
|
<RiMenuFill size="20" />
|
||||||
|
</button>
|
||||||
|
{menuOpen && (
|
||||||
|
<div className="absolute right-0 mt-2 w-48 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-20">
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
toggleIdorName();
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "Show Name" : "Show ID"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
setShowUnlisted(!showUnlisted);
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{showUnlisted ? <RiEyeOffFill size="20" /> : <RiEyeFill size="20" />}{" "}
|
||||||
|
{showUnlisted ? "Hide Unlisted" : "Show Unlisted"}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className="w-full text-left px-4 py-2 hover:bg-secondary-hover flex items-center gap-2"
|
||||||
|
onClick={() => {
|
||||||
|
handleUnloadAllModels();
|
||||||
|
setMenuOpen(false);
|
||||||
|
}}
|
||||||
|
disabled={isUnloading}
|
||||||
|
>
|
||||||
|
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{!isNarrow && (
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<button
|
||||||
|
className="btn text-base flex items-center gap-2"
|
||||||
|
onClick={toggleIdorName}
|
||||||
|
style={{ lineHeight: "1.2" }}
|
||||||
|
>
|
||||||
|
<RiSwapBoxFill size="20" /> {showIdorName === "id" ? "ID" : "Name"}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<button
|
||||||
|
className="btn text-base flex items-center gap-2"
|
||||||
|
onClick={() => setShowUnlisted(!showUnlisted)}
|
||||||
|
style={{ lineHeight: "1.2" }}
|
||||||
|
>
|
||||||
|
{showUnlisted ? <RiEyeFill size="20" /> : <RiEyeOffFill size="20" />} unlisted
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
className="btn text-base flex items-center gap-2"
|
||||||
|
onClick={handleUnloadAllModels}
|
||||||
|
disabled={isUnloading}
|
||||||
|
>
|
||||||
|
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex-1 overflow-y-auto">
|
||||||
|
<table className="w-full">
|
||||||
|
<thead className="sticky top-0 bg-card z-10">
|
||||||
|
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||||
|
<th>{showIdorName === "id" ? "Model ID" : "Name"}</th>
|
||||||
|
<th></th>
|
||||||
|
<th>State</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{filteredModels.map((model) => (
|
||||||
|
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
|
||||||
|
<td className={`${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||||
|
<a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank">
|
||||||
|
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
|
||||||
|
</a>
|
||||||
|
|
||||||
|
{!!model.description && (
|
||||||
|
<p className={model.unlisted ? "text-opacity-70" : ""}>
|
||||||
|
<em>{model.description}</em>
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
<td className="w-12">
|
||||||
|
{model.state === "stopped" ? (
|
||||||
|
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
|
||||||
|
Load
|
||||||
|
</button>
|
||||||
|
) : (
|
||||||
|
<button
|
||||||
|
className="btn btn--sm"
|
||||||
|
onClick={() => unloadSingleModel(model.id)}
|
||||||
|
disabled={model.state !== "ready"}
|
||||||
|
>
|
||||||
|
Unload
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
<td className="w-20">
|
||||||
|
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function StatsPanel() {
|
||||||
|
const { metrics } = useAPI();
|
||||||
|
|
||||||
|
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
|
||||||
|
const totalRequests = metrics.length;
|
||||||
|
if (totalRequests === 0) {
|
||||||
|
return [0, 0, 0];
|
||||||
|
}
|
||||||
|
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||||
|
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||||
|
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
|
||||||
|
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
|
||||||
|
}, [metrics]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="card">
|
||||||
|
<div className="rounded-lg overflow-hidden border border-gray-200 dark:border-white/10">
|
||||||
|
<table className="w-full">
|
||||||
|
<thead>
|
||||||
|
<tr className="border-b border-gray-200 dark:border-white/10 text-right">
|
||||||
|
<th>Requests</th>
|
||||||
|
<th className="border-l border-gray-200 dark:border-white/10">Processed</th>
|
||||||
|
<th className="border-l border-gray-200 dark:border-white/10">Generated</th>
|
||||||
|
<th className="border-l border-gray-200 dark:border-white/10">Tokens/Sec</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr className="text-right">
|
||||||
|
<td className="border-r border-gray-200 dark:border-white/10">{totalRequests}</td>
|
||||||
|
<td className="border-r border-gray-200 dark:border-white/10">
|
||||||
|
{new Intl.NumberFormat().format(totalInputTokens)}
|
||||||
|
</td>
|
||||||
|
<td className="border-r border-gray-200 dark:border-white/10">
|
||||||
|
{new Intl.NumberFormat().format(totalOutputTokens)}
|
||||||
|
</td>
|
||||||
|
<td>{avgTokensPerSecond}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
/// <reference types="vite/client" />
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"verbatimModuleSyntax": true,
|
||||||
|
"moduleDetection": "force",
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"erasableSyntaxOnly": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noUncheckedSideEffectImports": true
|
||||||
|
},
|
||||||
|
"include": ["src"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"files": [],
|
||||||
|
"references": [
|
||||||
|
{ "path": "./tsconfig.app.json" },
|
||||||
|
{ "path": "./tsconfig.node.json" }
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
||||||
|
"target": "ES2022",
|
||||||
|
"lib": ["ES2023"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"verbatimModuleSyntax": true,
|
||||||
|
"moduleDetection": "force",
|
||||||
|
"noEmit": true,
|
||||||
|
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"erasableSyntaxOnly": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noUncheckedSideEffectImports": true
|
||||||
|
},
|
||||||
|
"include": ["vite.config.ts"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
import { defineConfig } from "vite";
|
||||||
|
import react from "@vitejs/plugin-react";
|
||||||
|
import tailwindcss from "@tailwindcss/vite";
|
||||||
|
|
||||||
|
// https://vite.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react(), tailwindcss()],
|
||||||
|
base: "/ui/",
|
||||||
|
build: {
|
||||||
|
outDir: "../proxy/ui_dist",
|
||||||
|
assetsDir: "assets",
|
||||||
|
},
|
||||||
|
server: {
|
||||||
|
proxy: {
|
||||||
|
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||||
|
"/logs": "http://localhost:8080",
|
||||||
|
"/upstream": "http://localhost:8080",
|
||||||
|
"/unload": "http://localhost:8080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||