Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6439ab1515 | |||
| f94226122c | |||
| 7493618fdc | |||
| 205efd40a1 | |||
| 14207f8492 | |||
| 4e850c2834 | |||
| 75fced579e | |||
| b73f367f22 | |||
| 8f2137c72b | |||
| 124007cc98 | |||
| eb5bfff0b0 | |||
| 3edb180c08 | |||
| 66d555e625 | |||
| 4f863fd9fc | |||
| 267c030457 | |||
| c19309fe7e | |||
| 4413881b2d | |||
| 8df5e8563b | |||
| 7931212d3e | |||
| 3dc36032fb | |||
| addb98646f | |||
| 37d74efc2d | |||
| 22e098ac8b | |||
| 9864f9f517 | |||
| 53b32f3601 |
@@ -8,8 +8,15 @@ reviews:
|
||||
poem: false
|
||||
review_status: true
|
||||
collapse_walkthrough: false
|
||||
sequence_diagrams: false
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
enabled: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
issue_enrichment:
|
||||
planning:
|
||||
enabled: false
|
||||
|
||||
@@ -10,17 +10,36 @@ on:
|
||||
# Allows manual triggering of the workflow
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on workflow file changes (without pushing)
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/containers.yml'
|
||||
- 'docker/build-container.sh'
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [intel, cuda, vulkan, cpu, musa]
|
||||
platform: [intel, cuda, vulkan, cpu, musa, rocm]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
if: matrix.platform == 'rocm'
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -31,7 +50,7 @@ jobs:
|
||||
- name: Run build-container
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} true
|
||||
run: ./docker/build-container.sh ${{ matrix.platform }} ${{ github.event_name != 'push' }}
|
||||
|
||||
# note make sure mostlygeek/llama-swap has admin rights to the llama-swap package
|
||||
# see: https://github.com/actions/delete-package-versions/issues/74
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
# Project: llama-swap
|
||||
|
||||
## Project Description:
|
||||
|
||||
llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server.
|
||||
@@ -7,37 +5,45 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and react for UI (ui/)
|
||||
|
||||
## Testing
|
||||
|
||||
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests.
|
||||
- typescript, vite and react for UI (located in ui/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
### Plan Improvements
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use `gh` to create PRs and load issues
|
||||
- do include Co-Authored-By or created by when committing changes or creating PRs
|
||||
- keep PR descriptions short and focused on changes.
|
||||
- never include a test plan
|
||||
|
||||
Work plans are located in ai-plans/. Plans written by the user may be incomplete, contain inconsistencies or errors.
|
||||
## Testing
|
||||
|
||||
When the user asks to improve a plan follow these guidelines for expanding and improving it.
|
||||
- Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
|
||||
- Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
|
||||
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
|
||||
- Use `make test-all` before completing work. This includes long running concurrency tests.
|
||||
|
||||
- Identify any inconsistencies.
|
||||
- Expand plans out to be detailed specification of requirements and changes to be made.
|
||||
- Plans should have at least these sections:
|
||||
- Title - very short, describes changes
|
||||
- Overview: A more detailed summary of goal and outcomes desired
|
||||
- Design Requirements: Detailed descriptions of what needs to be done
|
||||
- Testing Plan: Tests to be implemented
|
||||
- Checklist: A detailed list of changes to be made
|
||||
### Commit message example format:
|
||||
|
||||
Look for "plan expansion" as explicit instructions to improve a plan.
|
||||
```
|
||||
proxy: add new feature
|
||||
|
||||
### Implementation of plans
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
When the user says "paint it", respond with "commencing automated assembly". Then implement the changes as described by the plan. Update the checklist as you complete items.
|
||||
- key change 1
|
||||
- key change 2
|
||||
- key change 3
|
||||
|
||||
## General Rules
|
||||
fixes #123
|
||||
```
|
||||
|
||||
- when summarizing changes only include details that require further action (action items)
|
||||
- when there are no action items, just say "Done."
|
||||
## Code Reviews
|
||||
|
||||
- use three levels High, Medium, Low severity
|
||||
- label each discovered issue with a label like H1, M2, L3 respectively
|
||||
- High severity are must fix issues (security, race conditions, critical bugs)
|
||||
- Medium severity are recommended improvements (coding style, missing functionality, inconsistencies)
|
||||
- Low severity are nice to have changes and nits
|
||||
- Include a suggestion with each discovered item
|
||||
- Limit your code review to three items with the highest priority first
|
||||
- Double check your discovered items and recommended remediations
|
||||
|
||||
@@ -18,11 +18,16 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
- ✅ OpenAI API supported endpoints:
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/responses`
|
||||
- `v1/embeddings`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- `v1/audio/voices`
|
||||
- `v1/images/generations`
|
||||
- `v1/images/edits`
|
||||
- ✅ Anthropic API supported endpoints:
|
||||
- `v1/messages`
|
||||
- `v1/messages/count_tokens`
|
||||
- ✅ llama-server (llama.cpp) supported endpoints
|
||||
- `v1/rerank`, `v1/reranking`, `/rerank`
|
||||
- `/infill` - for code infilling
|
||||
@@ -34,6 +39,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
|
||||
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
|
||||
- `/log` - remote log monitoring
|
||||
- `/health` - just returns "OK"
|
||||
- ✅ API Key support - define keys to restrict access to API endpoints
|
||||
- ✅ Customizable
|
||||
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
|
||||
- Automatic unloading of models after timeout by setting a `ttl`
|
||||
|
||||
+79
-1
@@ -188,11 +188,17 @@
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Used for server-side enforcement of sampling parameters."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests. Useful for enforcing specific parameter values. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings. Only stripParams is supported."
|
||||
"description": "Dictionary of filter settings. Supports stripParams and setParams."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -273,6 +279,78 @@
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
|
||||
},
|
||||
"logToStdout": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"proxy",
|
||||
"upstream",
|
||||
"both",
|
||||
"none"
|
||||
],
|
||||
"default": "proxy",
|
||||
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
|
||||
},
|
||||
"apiKeys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"default": [],
|
||||
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
|
||||
},
|
||||
"peers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"proxy",
|
||||
"models"
|
||||
],
|
||||
"properties": {
|
||||
"proxy": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
|
||||
},
|
||||
"apiKey": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"description": "A list of models served by the peer."
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stripParams": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"pattern": "^[a-zA-Z0-9_, ]*$",
|
||||
"description": "Comma separated list of parameters to remove from the request. Useful for removing parameters that the peer doesn't support."
|
||||
},
|
||||
"setParams": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"default": {},
|
||||
"description": "Dictionary of parameters to set/override in requests to this peer. Useful for injecting provider-specific settings. Protected params like 'model' cannot be overridden. Values can be strings, numbers, booleans, arrays, or objects."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {},
|
||||
"description": "Dictionary of filter settings for peer requests. Supports stripParams and setParams."
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+85
-1
@@ -80,6 +80,9 @@ includeAliasesInList: false
|
||||
# - macro names must not be a reserved name: PORT or MODEL_ID
|
||||
# - macro values can be numbers, bools, or strings
|
||||
# - macros can contain other macros, but they must be defined before they are used
|
||||
# - environment variables can be referenced with ${env.VAR_NAME} syntax
|
||||
# - env macros are substituted first, before regular macros
|
||||
# - if the env var is not set, config loading will fail with an error
|
||||
macros:
|
||||
# Example of a multi-line macro
|
||||
"latest-llama": >
|
||||
@@ -92,6 +95,24 @@ macros:
|
||||
# but they must be previously declared.
|
||||
"default_args": "--ctx-size ${default_ctx}"
|
||||
|
||||
# Example of environment variable macros
|
||||
# - ${env.VAR_NAME} pulls the value from the system environment
|
||||
# - useful for paths, secrets, or machine-specific configuration
|
||||
"models_dir": "${env.HOME}/models"
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# tip, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
|
||||
# use environment variable macros to keep secrets out of the config
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -175,7 +196,7 @@ models:
|
||||
|
||||
# filters: a dictionary of filter settings
|
||||
# - optional, default: empty dictionary
|
||||
# - only stripParams is currently supported
|
||||
# - same capabilities as peer filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
@@ -185,6 +206,16 @@ models:
|
||||
# - recommended to stick to sampling parameters
|
||||
stripParams: "temperature, top_p, top_k"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for enforcing specific parameter values
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce specific sampling parameters
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
|
||||
# metadata: a dictionary of arbitrary values that are included in /v1/models
|
||||
# - optional, default: empty dictionary
|
||||
# - while metadata can contains complex types it is recommended to keep it simple
|
||||
@@ -331,3 +362,56 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
# - can be a string or a macro
|
||||
apiKey: ${env.OPENROUTER_API_KEY}
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
# filters: a dictionary of filter settings for peer requests
|
||||
# - optional, default: empty dictionary
|
||||
# - same capabilities as model filters (stripParams, setParams)
|
||||
filters:
|
||||
# stripParams: a comma separated list of parameters to remove from the request
|
||||
# - optional, default: ""
|
||||
# - useful for removing parameters that the peer doesn't support
|
||||
# - the `model` parameter can never be removed
|
||||
stripParams: "temperature, top_p"
|
||||
|
||||
# setParams: a dictionary of parameters to set/override in requests to this peer
|
||||
# - optional, default: empty dictionary
|
||||
# - useful for injecting provider-specific settings like data retention policies
|
||||
# - protected params like "model" cannot be overridden
|
||||
# - values can be strings, numbers, booleans, arrays, or objects
|
||||
setParams:
|
||||
# Example: enforce zero-data-retention for OpenRouter
|
||||
provider:
|
||||
data_collection: "deny"
|
||||
zdr: true
|
||||
|
||||
+79
-14
@@ -2,21 +2,37 @@
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
# use this to test locally, example:
|
||||
# GITHUB_TOKEN=$(gh auth token) LOG_DEBUG=1 DEBUG_ABORT_BUILD=1 ./docker/build-container.sh rocm
|
||||
# you need read:package scope on the token. Generate a personal access token with
|
||||
# the scopes: gist, read:org, repo, write:packages
|
||||
# then: gh auth login (and copy/paste the new token)
|
||||
|
||||
log_debug() {
|
||||
if [ "$LOG_DEBUG" = "1" ]; then
|
||||
echo "[DEBUG] $*"
|
||||
fi
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo "[INFO] $*"
|
||||
}
|
||||
|
||||
ARCH=$1
|
||||
PUSH_IMAGES=${2:-false}
|
||||
|
||||
# List of allowed architectures
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu")
|
||||
ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu" "rocm")
|
||||
|
||||
# Check if ARCH is in the allowed list
|
||||
if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then
|
||||
echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
log_info "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if GITHUB_TOKEN is set and not empty
|
||||
if [[ -z "$GITHUB_TOKEN" ]]; then
|
||||
echo "Error: GITHUB_TOKEN is not set or is empty."
|
||||
log_info "Error: GITHUB_TOKEN is not set or is empty."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -32,25 +48,74 @@ LS_REPO=${GITHUB_REPOSITORY:-mostlygeek/llama-swap}
|
||||
# have to strip out the 'v' due to .tar.gz file naming
|
||||
LS_VER=$(curl -s https://api.github.com/repos/${LS_REPO}/releases/latest | jq -r .tag_name | sed 's/v//')
|
||||
|
||||
# Fetches the most recent llama.cpp tag matching the given prefix
|
||||
# Handles pagination to search beyond the first 100 results
|
||||
# $1 - tag_prefix (e.g., "server" or "server-vulkan")
|
||||
# Returns: the version number extracted from the tag
|
||||
fetch_llama_tag() {
|
||||
local tag_prefix=$1
|
||||
local page=1
|
||||
local per_page=100
|
||||
|
||||
while true; do
|
||||
log_debug "Fetching page $page for tag prefix: $tag_prefix"
|
||||
|
||||
local response=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions?per_page=${per_page}&page=${page}")
|
||||
|
||||
# Check for API errors
|
||||
if echo "$response" | jq -e '.message' > /dev/null 2>&1; then
|
||||
local error_msg=$(echo "$response" | jq -r '.message')
|
||||
log_info "GitHub API error: $error_msg"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if response is empty array (no more pages)
|
||||
if [ "$(echo "$response" | jq 'length')" -eq 0 ]; then
|
||||
log_debug "No more pages (empty response)"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract matching tag from this page
|
||||
local found_tag=$(echo "$response" | jq -r \
|
||||
".[] | select(.metadata.container.tags[]? | startswith(\"$tag_prefix\")) | .metadata.container.tags[] | select(startswith(\"$tag_prefix\"))" \
|
||||
| sort -r | head -n1)
|
||||
|
||||
if [ -n "$found_tag" ]; then
|
||||
log_debug "Found tag: $found_tag on page $page"
|
||||
echo "$found_tag" | awk -F '-' '{print $NF}'
|
||||
return 0
|
||||
fi
|
||||
|
||||
page=$((page + 1))
|
||||
|
||||
# Safety limit to prevent infinite loops
|
||||
if [ $page -gt 50 ]; then
|
||||
log_info "Reached pagination safety limit (50 pages)"
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
if [ "$ARCH" == "cpu" ]; then
|
||||
# cpu only containers just use the server tag
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r '.[] | select(.metadata.container.tags[] | startswith("server")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server")
|
||||
BASE_TAG=server-${LCPP_TAG}
|
||||
else
|
||||
LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \
|
||||
| jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \
|
||||
| sort -r | head -n1 | awk -F '-' '{print $3}')
|
||||
LCPP_TAG=$(fetch_llama_tag "server-${ARCH}")
|
||||
BASE_TAG=server-${ARCH}-${LCPP_TAG}
|
||||
fi
|
||||
|
||||
# Abort if LCPP_TAG is empty.
|
||||
if [[ -z "$LCPP_TAG" ]]; then
|
||||
echo "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
log_info "Abort: Could not find llama-server container for arch: $ARCH"
|
||||
exit 1
|
||||
else
|
||||
log_info "LCPP_TAG: $LCPP_TAG"
|
||||
fi
|
||||
|
||||
if [[ ! -z "$DEBUG_ABORT_BUILD" ]]; then
|
||||
log_info "Abort: DEBUG_ABORT_BUILD set"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for CONTAINER_TYPE in non-root root; do
|
||||
@@ -68,7 +133,7 @@ for CONTAINER_TYPE in non-root root; do
|
||||
USER_HOME=/app
|
||||
fi
|
||||
|
||||
echo "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
log_info "Building $CONTAINER_TYPE $CONTAINER_TAG $LS_VER"
|
||||
docker build -f llama-swap.Containerfile --build-arg BASE_TAG=${BASE_TAG} --build-arg LS_VER=${LS_VER} --build-arg UID=${USER_UID} \
|
||||
--build-arg LS_REPO=${LS_REPO} --build-arg GID=${USER_GID} --build-arg USER_HOME=${USER_HOME} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} .
|
||||
|
||||
+43
-1
@@ -86,7 +86,7 @@ llama-swap supports many more features to customize how you want to manage your
|
||||
## Full Configuration Example
|
||||
|
||||
> [!NOTE]
|
||||
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples.
|
||||
> Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
|
||||
|
||||
```yaml
|
||||
# add this modeline for validation in vscode
|
||||
@@ -161,6 +161,16 @@ sendLoadingState: true
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# apiKeys: require an API key when making requests to inference endpoints
|
||||
# - optional, default: []
|
||||
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
|
||||
# - each key is a non-empty string
|
||||
apiKeys:
|
||||
- "sk-hunter2"
|
||||
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
|
||||
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
|
||||
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -422,4 +432,36 @@ hooks:
|
||||
# otherwise models will be loaded and swapped out
|
||||
preload:
|
||||
- "llama"
|
||||
|
||||
# peers: a dictionary of remote peers and models they provide
|
||||
# - optional, default empty dictionary
|
||||
# - peers can be another llama-swap
|
||||
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
|
||||
peers:
|
||||
# keys is the peer'd ID
|
||||
llama-swap-peer:
|
||||
# proxy: a valid base URL to proxy requests to
|
||||
# - required
|
||||
# - requested path to llama-swap will be appended to the end of the proxy value
|
||||
proxy: http://192.168.1.23
|
||||
# models: a list of models served by the peer
|
||||
# - required
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
- embeddings/model_c
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
# apiKey: a string key to be injected into the request
|
||||
# - optional, default: ""
|
||||
# - if blank, no key will be added to the request
|
||||
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
|
||||
apiKey: sk-your-openrouter-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
- qwen/qwen3-235b-a22b-2507
|
||||
- deepseek/deepseek-v3.2
|
||||
- z-ai/glm-4.7
|
||||
- moonshotai/kimi-k2-0905
|
||||
- minimax/minimax-m2.1
|
||||
```
|
||||
|
||||
+142
-61
@@ -87,6 +87,7 @@ type GroupConfig struct {
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
@@ -143,6 +144,12 @@ type Config struct {
|
||||
|
||||
// present aliases to /v1/models OpenAI API listing
|
||||
IncludeAliasesInList bool `yaml:"includeAliasesInList"`
|
||||
|
||||
// support API keys, see issue #433, #50, #251
|
||||
RequiredAPIKeys []string `yaml:"apiKeys"`
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
}
|
||||
|
||||
func (c *Config) RealModelName(search string) (string, bool) {
|
||||
@@ -177,8 +184,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// default configuration values
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
@@ -187,13 +202,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
// set a minimum of 15 seconds
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
@@ -218,55 +231,46 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
/* check macro constraint rules:
|
||||
|
||||
- name must fit the regex ^[a-zA-Z0-9_-]+$
|
||||
- names must be less than 64 characters (no reason, just cause)
|
||||
- name can not be any reserved macros: PORT, MODEL_ID
|
||||
- macro values must be less than 1024 characters
|
||||
*/
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
|
||||
// Strip comments from command fields before macro expansion
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// validate model macros
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Merge global config and model macros. Model macros take precedence
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
|
||||
// Add global macros first
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (can override global)
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
// Remove any existing global macro with same name
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry // Override
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -276,23 +280,20 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
|
||||
// This allows later macros to reference earlier ones
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
// Substitute in command fields
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in metadata (recursive)
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
@@ -301,18 +302,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Final pass: check if PORT macro is needed after macro expansion
|
||||
// ${PORT} is a resource on the local machine so a new port is only allocated
|
||||
// if it is required in either cmd or proxy keys
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort { // either has it
|
||||
if !cmdHasPort && proxyHasPort { // but both don't have it
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
// Add PORT macro and substitute it
|
||||
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
@@ -320,10 +317,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
|
||||
// Substitute PORT in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
var err error
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
@@ -333,7 +328,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
@@ -347,35 +342,27 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // this is ok, has to be replaced by process later
|
||||
continue // replaced at runtime
|
||||
}
|
||||
// Reserved macros are always valid (they should have been substituted already)
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
// Any other macro is unknown
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unknown macros in metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the proxy URL.
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf(
|
||||
"model %s: invalid proxy URL: %w", modelId, err,
|
||||
)
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
// if sendLoadingState is nil, set it to the global config value
|
||||
// see #366
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState // copy it
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
@@ -383,18 +370,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
// Check for duplicates within this group
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
// Check if member is used in another group
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
@@ -402,7 +388,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean up hooks preload
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
@@ -414,10 +400,56 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -548,20 +580,26 @@ func validateMacro(name string, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
|
||||
func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -569,7 +607,7 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -628,3 +666,46 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values
|
||||
// Returns error if any env var is not set or contains invalid characters
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
result := s
|
||||
matches := envMacroRegex.FindAllStringSubmatch(s, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -761,3 +761,551 @@ models:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
content: `apiKeys: [""]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
content: `apiKeys: ["valid-key", "", "another-key"]`,
|
||||
expectedErr: "empty api key found in apiKeys",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, tt.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret-key-123")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY}"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"secret-key-123"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("multiple env substitutions in apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY_1", "key-one")
|
||||
t.Setenv("TEST_API_KEY_2", "key-two")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_API_KEY_1}", "${env.TEST_API_KEY_2}", "static-key"]`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key-one", "key-two", "static-key"}, config.RequiredAPIKeys)
|
||||
})
|
||||
|
||||
t.Run("missing env var in apiKeys", func(t *testing.T) {
|
||||
content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY")
|
||||
})
|
||||
|
||||
t.Run("env substitution results in empty key", func(t *testing.T) {
|
||||
t.Setenv("TEST_EMPTY_KEY", "")
|
||||
|
||||
content := `apiKeys: ["${env.TEST_EMPTY_KEY}"]`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "empty api key found in apiKeys", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_EnvMacros(t *testing.T) {
|
||||
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.TEST_MODEL_PATH}/llama-server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env substitution in multiple fields", func(t *testing.T) {
|
||||
t.Setenv("TEST_HOST", "myserver")
|
||||
t.Setenv("TEST_PORT", "9999")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --host ${env.TEST_HOST}"
|
||||
proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}"
|
||||
checkEndpoint: "http://${env.TEST_HOST}/health"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --host myserver", config.Models["test"].Cmd)
|
||||
assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy)
|
||||
assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint)
|
||||
})
|
||||
|
||||
t.Run("env in global macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_BASE_PATH", "/usr/local")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server"
|
||||
models:
|
||||
test:
|
||||
cmd: "${SERVER_PATH} --port 8080"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in model-level macro value", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_DIR", "/models/llama")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf"
|
||||
cmd: "server --model ${MODEL_FILE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env in metadata", func(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
api_key: "${env.TEST_API_KEY}"
|
||||
nested:
|
||||
key: "${env.TEST_API_KEY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"])
|
||||
nested := config.Models["test"].Metadata["nested"].(map[string]any)
|
||||
assert.Equal(t, "secret123", nested["key"])
|
||||
})
|
||||
|
||||
t.Run("env in filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("env in cmdStop", func(t *testing.T) {
|
||||
t.Setenv("TEST_KILL_SIGNAL", "SIGTERM")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --port ${PORT}"
|
||||
cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}"
|
||||
proxy: "http://localhost:${PORT}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM")
|
||||
})
|
||||
|
||||
t.Run("missing env var returns error", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "${env.UNDEFINED_VAR_12345}/server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in global macro", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
PATH: "${env.UNDEFINED_GLOBAL_VAR}"
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in model macro", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
macros:
|
||||
MY_PATH: "${env.UNDEFINED_MODEL_VAR}"
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing env var in metadata", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server"
|
||||
proxy: "http://localhost:8080"
|
||||
metadata:
|
||||
key: "${env.UNDEFINED_META_VAR}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "UNDEFINED_META_VAR")
|
||||
assert.Contains(t, err.Error(), "not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env combined with regular macros", func(t *testing.T) {
|
||||
t.Setenv("TEST_ROOT", "/data")
|
||||
|
||||
content := `
|
||||
macros:
|
||||
MODEL_BASE: "${env.TEST_ROOT}/models"
|
||||
models:
|
||||
test:
|
||||
cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("multiple env vars in same string", func(t *testing.T) {
|
||||
t.Setenv("TEST_USER", "admin")
|
||||
t.Setenv("TEST_PASS", "secret")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd)
|
||||
})
|
||||
|
||||
t.Run("env value with newline is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_MULTILINE", "line1\nline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_MULTILINE}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_MULTILINE")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with carriage return is rejected", func(t *testing.T) {
|
||||
t.Setenv("TEST_CR", "line1\rline2")
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --config ${env.TEST_CR}"
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "TEST_CR")
|
||||
assert.Contains(t, err.Error(), "newlines")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("env value with quotes is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_QUOTED", `value with "quotes"`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --arg \"${env.TEST_QUOTED}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Quotes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with quotes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `"quotes"`)
|
||||
})
|
||||
|
||||
t.Run("env value with backslash is escaped for YAML", func(t *testing.T) {
|
||||
t.Setenv("TEST_BACKSLASH", `path\to\file`)
|
||||
|
||||
content := `
|
||||
models:
|
||||
test:
|
||||
cmd: "server --path \"${env.TEST_BACKSLASH}\""
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
// Backslashes are escaped before YAML parsing, then YAML unescapes them
|
||||
// Final result preserves the original value with backslashes
|
||||
assert.Contains(t, config.Models["test"].Cmd, `path\to\file`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_PeerApiKey_EnvMacros(t *testing.T) {
|
||||
t.Run("env substitution in peer apiKey", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_API_KEY", "sk-peer-secret-123")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.TEST_PEER_API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-peer-secret-123", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("missing env var in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${env.NONEXISTENT_PEER_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
// With string-level env substitution, error only includes var name
|
||||
assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY")
|
||||
})
|
||||
|
||||
t.Run("static apiKey unchanged", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-static-key
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-static-key", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("multiple peers with env apiKeys", func(t *testing.T) {
|
||||
t.Setenv("TEST_PEER_KEY_1", "key-one")
|
||||
t.Setenv("TEST_PEER_KEY_2", "key-two")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
peer1:
|
||||
proxy: https://peer1.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_1}"
|
||||
models:
|
||||
- model-a
|
||||
peer2:
|
||||
proxy: https://peer2.example.com
|
||||
apiKey: "${env.TEST_PEER_KEY_2}"
|
||||
models:
|
||||
- model-b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key-one", config.Peers["peer1"].ApiKey)
|
||||
assert.Equal(t, "key-two", config.Peers["peer2"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro substitution in peer apiKey", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
API_KEY: sk-from-global-macro
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${API_KEY}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sk-from-global-macro", config.Peers["openrouter"].ApiKey)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.stripParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
STRIP_LIST: "temperature, top_p"
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${STRIP_LIST}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "temperature, top_p", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("global macro in peer filters.setParams", func(t *testing.T) {
|
||||
content := `
|
||||
macros:
|
||||
MAX_TOKENS: 4096
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
max_tokens: "${MAX_TOKENS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4096, config.Peers["openrouter"].Filters.SetParams["max_tokens"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.setParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_RETENTION_POLICY", "deny")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
data_collection: "${env.TEST_RETENTION_POLICY}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "deny", config.Peers["openrouter"].Filters.SetParams["data_collection"])
|
||||
})
|
||||
|
||||
t.Run("env macro in peer filters.stripParams", func(t *testing.T) {
|
||||
t.Setenv("TEST_STRIP_PARAMS", "frequency_penalty, presence_penalty")
|
||||
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
stripParams: "${env.TEST_STRIP_PARAMS}"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "frequency_penalty, presence_penalty", config.Peers["openrouter"].Filters.StripParams)
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer apiKey fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: "${UNDEFINED_MACRO}"
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.apiKey")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
|
||||
t.Run("unknown macro in peer filters.setParams fails", func(t *testing.T) {
|
||||
content := `
|
||||
peers:
|
||||
openrouter:
|
||||
proxy: https://openrouter.ai/api
|
||||
models:
|
||||
- llama-3.1-8b
|
||||
filters:
|
||||
setParams:
|
||||
value: "${UNDEFINED_MACRO}"
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peers.openrouter.filters.setParams")
|
||||
assert.Contains(t, err.Error(), "unknown macro")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
|
||||
// These are protected to prevent breaking the proxy's ability to route requests correctly
|
||||
var ProtectedParams = []string{"model"}
|
||||
|
||||
// Filters contains filter settings for modifying request parameters
|
||||
// Used by both models and peers
|
||||
type Filters struct {
|
||||
// StripParams is a comma-separated list of parameters to remove from requests
|
||||
// The "model" parameter can never be removed
|
||||
StripParams string `yaml:"stripParams"`
|
||||
|
||||
// SetParams is a dictionary of parameters to set/override in requests
|
||||
// Protected params (like "model") cannot be set
|
||||
SetParams map[string]any `yaml:"setParams"`
|
||||
}
|
||||
|
||||
// SanitizedStripParams returns a sorted list of parameters to strip,
|
||||
// with duplicates, empty strings, and protected params removed
|
||||
func (f Filters) SanitizedStripParams() []string {
|
||||
if f.StripParams == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
// Skip protected params, empty strings, and duplicates
|
||||
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
if len(cleaned) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slices.Sort(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizedSetParams returns a copy of SetParams with protected params removed
|
||||
// and keys sorted for consistent iteration order
|
||||
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
|
||||
if len(f.SetParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any, len(f.SetParams))
|
||||
keys := make([]string, 0, len(f.SetParams))
|
||||
|
||||
for key, value := range f.SetParams {
|
||||
// Skip protected params
|
||||
if slices.Contains(ProtectedParams, key) {
|
||||
continue
|
||||
}
|
||||
result[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Sort keys for consistent ordering
|
||||
sort.Strings(keys)
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, keys
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilters_SanitizedStripParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripParams string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
stripParams: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
stripParams: "temperature",
|
||||
want: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
stripParams: "temperature, top_p, top_k",
|
||||
want: []string{"temperature", "top_k", "top_p"}, // sorted
|
||||
},
|
||||
{
|
||||
name: "model param filtered",
|
||||
stripParams: "model, temperature, top_p",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "only model param",
|
||||
stripParams: "model",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "duplicates removed",
|
||||
stripParams: "temperature, top_p, temperature",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
stripParams: " temperature , top_p ",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "empty values filtered",
|
||||
stripParams: "temperature,,top_p,",
|
||||
want: []string{"temperature", "top_p"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{StripParams: tt.stripParams}
|
||||
got := f.SanitizedStripParams()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters_SanitizedSetParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setParams map[string]any
|
||||
wantParams map[string]any
|
||||
wantKeys []string
|
||||
}{
|
||||
{
|
||||
name: "empty setParams",
|
||||
setParams: nil,
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
setParams: map[string]any{},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "normal params",
|
||||
setParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
wantKeys: []string{"temperature", "top_p"},
|
||||
},
|
||||
{
|
||||
name: "protected model param filtered",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"temperature": 0.7,
|
||||
},
|
||||
wantKeys: []string{"temperature"},
|
||||
},
|
||||
{
|
||||
name: "only protected param",
|
||||
setParams: map[string]any{
|
||||
"model": "should-be-filtered",
|
||||
},
|
||||
wantParams: nil,
|
||||
wantKeys: nil,
|
||||
},
|
||||
{
|
||||
name: "complex nested values",
|
||||
setParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantParams: map[string]any{
|
||||
"provider": map[string]any{
|
||||
"data_collection": "deny",
|
||||
"allow_fallbacks": false,
|
||||
},
|
||||
"transforms": []string{"middle-out"},
|
||||
},
|
||||
wantKeys: []string{"provider", "transforms"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := Filters{SetParams: tt.setParams}
|
||||
gotParams, gotKeys := f.SanitizedSetParams()
|
||||
|
||||
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
|
||||
for i, key := range gotKeys {
|
||||
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
|
||||
}
|
||||
|
||||
if tt.wantParams == nil {
|
||||
assert.Nil(t, gotParams, "expected nil params")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
|
||||
for key, wantValue := range tt.wantParams {
|
||||
gotValue, exists := gotParams[key]
|
||||
assert.True(t, exists, "missing key: %s", key)
|
||||
// Simple comparison for basic types
|
||||
switch v := wantValue.(type) {
|
||||
case string, int, float64, bool:
|
||||
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedParams(t *testing.T) {
|
||||
// Verify that "model" is protected
|
||||
assert.Contains(t, ProtectedParams, "model")
|
||||
}
|
||||
@@ -3,8 +3,6 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
@@ -74,16 +72,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||
return SanitizeCommand(m.Cmd)
|
||||
}
|
||||
|
||||
// ModelFilters see issue #174
|
||||
// ModelFilters embeds Filters and adds legacy support for strip_params field
|
||||
// See issue #174
|
||||
type ModelFilters struct {
|
||||
StripParams string `yaml:"stripParams"`
|
||||
Filters `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawModelFilters ModelFilters
|
||||
defaults := rawModelFilters{
|
||||
StripParams: "",
|
||||
}
|
||||
defaults := rawModelFilters{}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
@@ -104,25 +101,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
|
||||
// Returns ([]string, error) to match existing API
|
||||
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
|
||||
if f.StripParams == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := strings.Split(f.StripParams, ",")
|
||||
cleaned := make([]string, 0, len(params))
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, param := range params {
|
||||
trimmed := strings.TrimSpace(param)
|
||||
if trimmed == "model" || trimmed == "" || seen[trimmed] {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = true
|
||||
cleaned = append(cleaned, trimmed)
|
||||
}
|
||||
|
||||
// sort cleaned
|
||||
slices.Sort(cleaned)
|
||||
return cleaned, nil
|
||||
return f.Filters.SanitizedStripParams(), nil
|
||||
}
|
||||
|
||||
@@ -72,3 +72,35 @@ models:
|
||||
assert.True(t, *config.Models["model2"].SendLoadingState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --port ${PORT}
|
||||
filters:
|
||||
stripParams: "top_k"
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
top_p: 0.9
|
||||
stop:
|
||||
- "<|end|>"
|
||||
- "<|stop|>"
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelConfig := config.Models["model1"]
|
||||
|
||||
// Check stripParams
|
||||
stripParams, err := modelConfig.Filters.SanitizedStripParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"top_k"}, stripParams)
|
||||
|
||||
// Check setParams
|
||||
setParams, keys := modelConfig.Filters.SanitizedSetParams()
|
||||
assert.NotNil(t, setParams)
|
||||
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
|
||||
assert.Equal(t, 0.7, setParams["temperature"])
|
||||
assert.Equal(t, 0.9, setParams["top_p"])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type PeerDictionaryConfig map[string]PeerConfig
|
||||
type PeerConfig struct {
|
||||
Proxy string `yaml:"proxy"`
|
||||
ProxyURL *url.URL `yaml:"-"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Models []string `yaml:"models"`
|
||||
Filters Filters `yaml:"filters"`
|
||||
}
|
||||
|
||||
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawPeerConfig PeerConfig
|
||||
defaults := rawPeerConfig{
|
||||
Proxy: "",
|
||||
ApiKey: "",
|
||||
Models: []string{},
|
||||
Filters: Filters{},
|
||||
}
|
||||
|
||||
if err := unmarshal(&defaults); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate proxy is not empty
|
||||
if defaults.Proxy == "" {
|
||||
return fmt.Errorf("proxy is required")
|
||||
}
|
||||
|
||||
// Validate proxy is a valid URL and store the parsed value
|
||||
parsedURL, err := url.Parse(defaults.Proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
|
||||
}
|
||||
defaults.ProxyURL = parsedURL
|
||||
|
||||
// Validate models is not empty
|
||||
if len(defaults.Models) == 0 {
|
||||
return fmt.Errorf("peer models can not be empty")
|
||||
}
|
||||
|
||||
*c = PeerConfig(defaults)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
proxy: http://192.168.1.23
|
||||
models:
|
||||
- model_a
|
||||
- model_b
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid config with apiKey",
|
||||
yaml: `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test-key
|
||||
models:
|
||||
- meta-llama/llama-3.1-8b-instruct
|
||||
`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing proxy",
|
||||
yaml: `
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "empty proxy",
|
||||
yaml: `
|
||||
proxy: ""
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "proxy is required",
|
||||
},
|
||||
{
|
||||
name: "invalid proxy URL",
|
||||
yaml: `
|
||||
proxy: "://invalid"
|
||||
models:
|
||||
- model_a
|
||||
`,
|
||||
wantErr: "invalid peer proxy URL",
|
||||
},
|
||||
{
|
||||
name: "missing models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
{
|
||||
name: "empty models",
|
||||
yaml: `
|
||||
proxy: http://localhost:8080
|
||||
models: []
|
||||
`,
|
||||
wantErr: "peer models can not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(tt.yaml), &config)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.wantErr)
|
||||
} else if !contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_ProxyURL(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: http://192.168.1.23:8080/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL == nil {
|
||||
t.Fatal("ProxyURL should not be nil")
|
||||
}
|
||||
|
||||
if config.ProxyURL.Host != "192.168.1.23:8080" {
|
||||
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Scheme != "http" {
|
||||
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
|
||||
}
|
||||
|
||||
if config.ProxyURL.Path != "/api" {
|
||||
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchSubstring(s, substr)
|
||||
}
|
||||
|
||||
func searchSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
setParams:
|
||||
temperature: 0.7
|
||||
provider:
|
||||
data_collection: deny
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
|
||||
if config.Filters.SetParams["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
|
||||
}
|
||||
|
||||
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("provider should be a map")
|
||||
}
|
||||
if provider["data_collection"] != "deny" {
|
||||
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerConfig_WithBothFilters(t *testing.T) {
|
||||
yamlData := `
|
||||
proxy: https://openrouter.ai/api
|
||||
apiKey: sk-test
|
||||
models:
|
||||
- model_a
|
||||
filters:
|
||||
stripParams: "temperature, top_p"
|
||||
setParams:
|
||||
max_tokens: 1000
|
||||
`
|
||||
var config PeerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &config)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check stripParams
|
||||
stripParams := config.Filters.SanitizedStripParams()
|
||||
if len(stripParams) != 2 {
|
||||
t.Errorf("expected 2 strip params, got %d", len(stripParams))
|
||||
}
|
||||
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
|
||||
t.Errorf("unexpected strip params: %v", stripParams)
|
||||
}
|
||||
|
||||
// Check setParams
|
||||
if config.Filters.SetParams == nil {
|
||||
t.Fatal("Filters.SetParams should not be nil")
|
||||
}
|
||||
if config.Filters.SetParams["max_tokens"] != 1000 {
|
||||
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
|
||||
}
|
||||
}
|
||||
+78
-10
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -96,6 +98,12 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||
) error {
|
||||
recorder := newBodyCopier(writer)
|
||||
|
||||
// Filter Accept-Encoding to only include encodings we can decompress for metrics
|
||||
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
|
||||
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
|
||||
}
|
||||
|
||||
if err := next(modelID, recorder, request); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,17 +116,36 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize default metrics - these will always be recorded
|
||||
tm := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: modelID,
|
||||
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
mp.logger.Warn("metrics skipped, empty body")
|
||||
mp.logger.Warn("metrics: empty body, recording minimal metrics")
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||
} else {
|
||||
// Decompress if needed
|
||||
if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
|
||||
var err error
|
||||
body, err = decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
tm = parsed
|
||||
}
|
||||
} else {
|
||||
if gjson.ValidBytes(body) {
|
||||
@@ -127,18 +154,18 @@ func (mp *metricsMonitor) wrapHandler(
|
||||
timings := parsed.Get("timings")
|
||||
|
||||
if usage.Exists() || timings.Exists() {
|
||||
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
|
||||
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
|
||||
} else {
|
||||
mp.addMetrics(tm)
|
||||
tm = parsedMetrics
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||
mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
mp.addMetrics(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -251,6 +278,25 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decompressBody decompresses the body based on Content-Encoding header
|
||||
func decompressBody(body []byte, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(encoding)) {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
case "deflate":
|
||||
reader := flate.NewReader(bytes.NewReader(body))
|
||||
defer reader.Close()
|
||||
return io.ReadAll(reader)
|
||||
default:
|
||||
return body, nil // Return as-is for unknown/no encoding
|
||||
}
|
||||
}
|
||||
|
||||
// responseBodyCopier records the response body and writes to the original response writer
|
||||
// while also capturing it in a buffer for later processing
|
||||
type responseBodyCopier struct {
|
||||
@@ -289,3 +335,25 @@ func (w *responseBodyCopier) Header() http.Header {
|
||||
func (w *responseBodyCopier) StartTime() time.Time {
|
||||
return w.start
|
||||
}
|
||||
|
||||
// filterAcceptEncoding filters the Accept-Encoding header to only include
|
||||
// encodings we can decompress (gzip, deflate). This respects the client's
|
||||
// preferences while ensuring we can parse response bodies for metrics.
|
||||
func filterAcceptEncoding(acceptEncoding string) string {
|
||||
if acceptEncoding == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
supported := map[string]bool{"gzip": true, "deflate": true}
|
||||
var filtered []string
|
||||
|
||||
for _, part := range strings.Split(acceptEncoding, ",") {
|
||||
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
|
||||
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
|
||||
if supported[strings.ToLower(encoding)] {
|
||||
filtered = append(filtered, strings.TrimSpace(part))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filtered, ", ")
|
||||
}
|
||||
|
||||
+154
-13
@@ -1,6 +1,9 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -291,7 +294,7 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||
t.Run("empty response body records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -307,10 +310,13 @@ data: [DONE]
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||
t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -328,7 +334,10 @@ data: [DONE]
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||
@@ -350,7 +359,7 @@ data: [DONE]
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
})
|
||||
|
||||
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||
t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{"result": "ok"}`
|
||||
@@ -367,10 +376,13 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -598,7 +610,7 @@ data: [DONE]
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||
t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `data: not json
|
||||
@@ -619,13 +631,16 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||
t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := ``
|
||||
@@ -642,11 +657,13 @@ data: [DONE]
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
// Empty body should not trigger WrapHandler processing
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 0, len(metrics))
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -691,3 +708,127 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||
mm.addMetrics(metric)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
||||
t.Run("gzip encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
|
||||
|
||||
// Compress with gzip
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
gzWriter.Write([]byte(responseBody))
|
||||
gzWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("deflate encoded response", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
|
||||
|
||||
// Compress with deflate
|
||||
var buf bytes.Buffer
|
||||
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
|
||||
flateWriter.Write([]byte(responseBody))
|
||||
flateWriter.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "deflate")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(compressedBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 200, metrics[0].InputTokens)
|
||||
assert.Equal(t, 75, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
// Invalid compressed data
|
||||
invalidData := []byte("this is not gzip data")
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(invalidData)
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err) // Should not return error, just log warning
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "test-model", metrics[0].Model)
|
||||
assert.Equal(t, 0, metrics[0].InputTokens)
|
||||
assert.Equal(t, 0, metrics[0].OutputTokens)
|
||||
})
|
||||
|
||||
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
|
||||
mm := newMetricsMonitor(testLogger, 10)
|
||||
|
||||
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
|
||||
|
||||
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "unknown-encoding")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(responseBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
metrics := mm.getMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, 300, metrics[0].InputTokens)
|
||||
assert.Equal(t, 100, metrics[0].OutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
)
|
||||
|
||||
type peerProxyMember struct {
|
||||
peerID string
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type PeerProxy struct {
|
||||
peers config.PeerDictionaryConfig
|
||||
proxyMap map[string]*peerProxyMember
|
||||
}
|
||||
|
||||
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
|
||||
proxyMap := make(map[string]*peerProxyMember)
|
||||
|
||||
// Sort peer IDs for consistent iteration order
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
for peerID := range peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
sort.Strings(peerIDs)
|
||||
|
||||
// Create a shared transport with reasonable timeouts for peer connections
|
||||
// these can be tuned with feedback later
|
||||
peerTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second, // Connection timeout
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
peer := peers[peerID]
|
||||
// Create reverse proxy for this peer
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
|
||||
reverseProxy.Transport = peerTransport
|
||||
|
||||
// Wrap Director to set Host header for remote hosts (not localhost)
|
||||
originalDirector := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Ensure Host header matches target URL for remote proxying
|
||||
req.Host = req.URL.Host
|
||||
}
|
||||
|
||||
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||
resp.Header.Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
|
||||
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
||||
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
||||
}
|
||||
http.Error(w, errMsg, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
pp := &peerProxyMember{
|
||||
peerID: peerID,
|
||||
reverseProxy: reverseProxy,
|
||||
apiKey: peer.ApiKey,
|
||||
}
|
||||
|
||||
// Map each model to this peer's proxy
|
||||
for _, modelID := range peer.Models {
|
||||
if _, found := proxyMap[modelID]; found {
|
||||
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
||||
continue
|
||||
}
|
||||
proxyMap[modelID] = pp
|
||||
}
|
||||
}
|
||||
|
||||
return &PeerProxy{
|
||||
peers: peers,
|
||||
proxyMap: proxyMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PeerProxy) HasPeerModel(modelID string) bool {
|
||||
_, found := p.proxyMap[modelID]
|
||||
return found
|
||||
}
|
||||
|
||||
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
|
||||
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
|
||||
pp, found := p.proxyMap[modelID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
// Get the peer config using the peerID
|
||||
peer, found := p.peers[pp.peerID]
|
||||
if !found {
|
||||
return config.Filters{}
|
||||
}
|
||||
return peer.Filters
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
|
||||
return p.peers
|
||||
}
|
||||
|
||||
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
|
||||
pp, found := p.proxyMap[model_id]
|
||||
if !found {
|
||||
return fmt.Errorf("no peer proxy found for model %s", model_id)
|
||||
}
|
||||
|
||||
// Inject API key if configured for this peer
|
||||
if pp.apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
||||
request.Header.Set("x-api-key", pp.apiKey)
|
||||
}
|
||||
|
||||
pp.reverseProxy.ServeHTTP(writer, request)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pm)
|
||||
assert.Empty(t, pm.proxyMap)
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_SinglePeer(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "test-key",
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 2)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.False(t, pm.HasPeerModel("model-c"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"model-a", "model-b"},
|
||||
},
|
||||
"peer2": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"model-c", "model-d"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pm.proxyMap, 4)
|
||||
assert.True(t, pm.HasPeerModel("model-a"))
|
||||
assert.True(t, pm.HasPeerModel("model-b"))
|
||||
assert.True(t, pm.HasPeerModel("model-c"))
|
||||
assert.True(t, pm.HasPeerModel("model-d"))
|
||||
}
|
||||
|
||||
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
|
||||
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
|
||||
// should be mapped, and a warning should be logged
|
||||
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
|
||||
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"alpha-peer": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL1,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
"beta-peer": config.PeerConfig{
|
||||
Proxy: "http://peer2.example.com:8080",
|
||||
ProxyURL: proxyURL2,
|
||||
Models: []string{"duplicate-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
// Should only have one entry for the duplicate model
|
||||
assert.Len(t, pm.proxyMap, 1)
|
||||
assert.True(t, pm.HasPeerModel("duplicate-model"))
|
||||
}
|
||||
|
||||
func TestHasPeerModel(t *testing.T) {
|
||||
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: "http://peer1.example.com:8080",
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"existing-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, pm.HasPeerModel("existing-model"))
|
||||
assert.False(t, pm.HasPeerModel("non-existing-model"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_ModelNotFound(t *testing.T) {
|
||||
peers := config.PeerDictionaryConfig{}
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("non-existing-model", w, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
|
||||
}
|
||||
|
||||
func TestProxyRequest_Success(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response from peer"))
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "response from peer", w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "secret-api-key",
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_NoApiKey(t *testing.T) {
|
||||
// Create a test server that checks for the Authorization header
|
||||
var receivedAuthHeader string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
ApiKey: "", // No API key
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, receivedAuthHeader)
|
||||
}
|
||||
|
||||
func TestProxyRequest_HostHeaderSet(t *testing.T) {
|
||||
// Create a test server that checks the Host header
|
||||
var receivedHost string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The Host header should be set to the target URL's host
|
||||
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
|
||||
}
|
||||
|
||||
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
|
||||
// Create a test server that returns SSE content type
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
proxyURL, _ := url.Parse(testServer.URL)
|
||||
peers := config.PeerDictionaryConfig{
|
||||
"peer1": config.PeerConfig{
|
||||
Proxy: testServer.URL,
|
||||
ProxyURL: proxyURL,
|
||||
Models: []string{"test-model"},
|
||||
},
|
||||
}
|
||||
|
||||
pm, err := NewPeerProxy(peers, testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err = pm.ProxyRequest("test-model", w, req)
|
||||
assert.NoError(t, err)
|
||||
// The X-Accel-Buffering header should be set to "no" for SSE
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
}
|
||||
@@ -395,6 +395,10 @@ func TestProcess_StopImmediately(t *testing.T) {
|
||||
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
|
||||
// the upstream command
|
||||
func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping SIGTERM test on Windows ")
|
||||
}
|
||||
|
||||
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
|
||||
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
||||
// and multiple requests are made in parallel, only one process is running at a time.
|
||||
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
|
||||
+268
-111
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
@@ -50,6 +51,9 @@ type ProxyManager struct {
|
||||
buildDate string
|
||||
commit string
|
||||
version string
|
||||
|
||||
// peer proxy see: #296, #433
|
||||
peerProxy *PeerProxy
|
||||
}
|
||||
|
||||
func New(proxyConfig config.Config) *ProxyManager {
|
||||
@@ -133,6 +137,12 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
maxMetrics = proxyConfig.MetricsMaxInMemory
|
||||
}
|
||||
|
||||
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
|
||||
if err != nil {
|
||||
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
|
||||
peerProxy = nil
|
||||
}
|
||||
|
||||
pm := &ProxyManager{
|
||||
config: proxyConfig,
|
||||
ginEngine: gin.New(),
|
||||
@@ -151,6 +161,8 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
buildDate: "unknown",
|
||||
commit: "abcd1234",
|
||||
version: "0",
|
||||
|
||||
peerProxy: peerProxy,
|
||||
}
|
||||
|
||||
// create the process groups
|
||||
@@ -166,22 +178,29 @@ func New(proxyConfig config.Config) *ProxyManager {
|
||||
// do it in the background, don't block startup -- not sure if good idea yet
|
||||
go func() {
|
||||
discardWriter := &DiscardWriter{}
|
||||
for _, realModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
proxyLogger.Infof("Preloading model: %s", realModelName)
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
|
||||
modelID, ok := proxyConfig.RealModelName(preloadModelName)
|
||||
|
||||
if !ok {
|
||||
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyLogger.Infof("Preloading model: %s", modelID)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
|
||||
if err != nil {
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: false,
|
||||
})
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
|
||||
proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
|
||||
continue
|
||||
} else {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
processGroup.ProxyRequest(realModelName, discardWriter, req)
|
||||
processGroup.ProxyRequest(modelID, discardWriter, req)
|
||||
event.Emit(ModelPreloadedEvent{
|
||||
ModelName: realModelName,
|
||||
ModelName: modelID,
|
||||
Success: true,
|
||||
})
|
||||
}
|
||||
@@ -256,37 +275,44 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
})
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler)
|
||||
// Protected routes use pm.apiKeyAuth() middleware
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
|
||||
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
// Support anthropic count_tokens API (Also added in the above PR)
|
||||
pm.ginEngine.POST("/v1/messages/count_tokens", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support embeddings and reranking
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /reranking endpoint + aliases
|
||||
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /infill endpoint for code infilling
|
||||
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
|
||||
// llama-server's /completion endpoint
|
||||
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
|
||||
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
|
||||
|
||||
// in proxymanager_loghandlers.go
|
||||
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
|
||||
pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
|
||||
|
||||
/**
|
||||
* User Interface Endpoints
|
||||
@@ -298,9 +324,9 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream)
|
||||
pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
@@ -398,16 +424,10 @@ func (pm *ProxyManager) Shutdown() {
|
||||
pm.shutdownCancel()
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
||||
// de-alias the real model name and get a real one
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
||||
return nil, fmt.Errorf("could not find process group for model %s", realModelName)
|
||||
}
|
||||
|
||||
if processGroup.exclusive {
|
||||
@@ -419,54 +439,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
|
||||
}
|
||||
}
|
||||
|
||||
return processGroup, realModelName, nil
|
||||
return processGroup, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
data := make([]gin.H, 0, len(pm.config.Models))
|
||||
createdTime := time.Now().Unix()
|
||||
|
||||
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
for id, modelConfig := range pm.config.Models {
|
||||
if modelConfig.Unlisted {
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := func(modelId string) gin.H {
|
||||
record := gin.H{
|
||||
"id": modelId,
|
||||
"object": "model",
|
||||
"created": createdTime,
|
||||
"owned_by": "llama-swap",
|
||||
}
|
||||
|
||||
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
||||
record["name"] = name
|
||||
}
|
||||
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
||||
record["description"] = desc
|
||||
}
|
||||
|
||||
// Add metadata if present
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
record["meta"] = gin.H{
|
||||
"llamaswap": modelConfig.Metadata,
|
||||
}
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
data = append(data, newRecord(id))
|
||||
data = append(data, newRecord(id, modelConfig))
|
||||
|
||||
// Include aliases
|
||||
if pm.config.IncludeAliasesInList {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if alias := strings.TrimSpace(alias); alias != "" {
|
||||
data = append(data, newRecord(alias))
|
||||
data = append(data, newRecord(alias, modelConfig))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
// add peer models
|
||||
for _, modelID := range peer.Models {
|
||||
// Skip unlisted models if not showing them
|
||||
record := newRecord(modelID, config.ModelConfig{
|
||||
Name: fmt.Sprintf("%s: %s", peerID, modelID),
|
||||
Metadata: map[string]any{
|
||||
"peerID": peerID,
|
||||
},
|
||||
})
|
||||
|
||||
data = append(data, record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by the "id" key
|
||||
sort.Slice(data, func(i, j int) bool {
|
||||
si, _ := data[i]["id"].(string)
|
||||
@@ -505,8 +542,8 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam
|
||||
searchModelName = searchModelName + "/" + part
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, real, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
if modelID, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,23 +553,22 @@ func (pm *ProxyManager) findModelInPath(path string) (searchName string, realNam
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
searchModelName, modelName, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is exactly a model name with no additional path
|
||||
// and doesn't end with a trailing slash
|
||||
// Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
|
||||
// This ensures relative URLs in upstream responses resolve correctly and
|
||||
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
|
||||
// HTTP method (301 would downgrade to GET).
|
||||
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
|
||||
// Build new URL with query parameters preserved
|
||||
newPath := "/upstream/" + searchModelName + "/"
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
newPath += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// Use 308 for non-GET/HEAD requests to preserve method
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||
c.Redirect(http.StatusMovedPermanently, newPath)
|
||||
} else {
|
||||
@@ -541,7 +577,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -553,15 +589,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
|
||||
// attempt to record metrics if it is a POST request
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -580,41 +616,90 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
// Look for a matching local model first
|
||||
var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if found {
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
if useModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// issue #174 strip parameters from the JSON body
|
||||
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
|
||||
if err != nil { // just log it and continue
|
||||
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
|
||||
} else {
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue #453 set/override parameters in the JSON body
|
||||
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
|
||||
nextHandler = processGroup.ProxyRequest
|
||||
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
|
||||
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
|
||||
modelID = requestedModel
|
||||
|
||||
// issue #453 apply filters for peer requests
|
||||
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
|
||||
|
||||
// Apply stripParams - remove specified parameters from request
|
||||
stripParams := peerFilters.SanitizedStripParams()
|
||||
for _, param := range stripParams {
|
||||
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
|
||||
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply setParams - set/override specified parameters in request
|
||||
setParams, setParamKeys := peerFilters.SanitizedSetParams()
|
||||
for _, key := range setParamKeys {
|
||||
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextHandler = pm.peerProxy.ProxyRequest
|
||||
}
|
||||
|
||||
if nextHandler == nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -627,19 +712,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
|
||||
// issue #366 extract values that downstream handlers may need
|
||||
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
|
||||
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
|
||||
ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||
if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||
if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -659,7 +744,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
modelID, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, err := pm.swapProcessGroup(modelID)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
@@ -677,7 +768,7 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" {
|
||||
// # issue #69 allow custom model names to be sent to upstream
|
||||
useModelName := pm.config.Models[realModelName].UseModelName
|
||||
useModelName := pm.config.Models[modelID].UseModelName
|
||||
|
||||
if useModelName != "" {
|
||||
fieldValue = useModelName
|
||||
@@ -748,9 +839,9 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||
|
||||
// Use the modified request for proxying
|
||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||
if err := processGroup.ProxyRequest(modelID, c.Writer, modifiedReq); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, modelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -765,6 +856,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
|
||||
}
|
||||
}
|
||||
|
||||
// apiKeyAuth returns a middleware that validates API keys if configured.
|
||||
// Returns a pass-through handler if no API keys are configured.
|
||||
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
|
||||
if len(pm.config.RequiredAPIKeys) == 0 {
|
||||
return func(c *gin.Context) { c.Next() }
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
xApiKey := c.GetHeader("x-api-key")
|
||||
|
||||
var bearerKey string
|
||||
var basicKey string
|
||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
bearerKey = strings.TrimPrefix(auth, "Bearer ")
|
||||
} else if strings.HasPrefix(auth, "Basic ") {
|
||||
// Basic Auth: base64(username:password), password is the API key
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
basicKey = parts[1] // password is the API key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use first key found: Basic, then Bearer, then x-api-key
|
||||
var providedKey string
|
||||
if basicKey != "" {
|
||||
providedKey = basicKey
|
||||
} else if bearerKey != "" {
|
||||
providedKey = bearerKey
|
||||
} else {
|
||||
providedKey = xApiKey
|
||||
}
|
||||
|
||||
// Validate key
|
||||
valid := false
|
||||
for _, key := range pm.config.RequiredAPIKeys {
|
||||
if providedKey == key {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
|
||||
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Strip auth headers to prevent leakage to upstream
|
||||
c.Request.Header.Del("Authorization")
|
||||
c.Request.Header.Del("x-api-key")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
||||
pm.StopProcesses(StopImmediately)
|
||||
c.String(http.StatusOK, "OK")
|
||||
@@ -778,8 +930,13 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
||||
for _, process := range processGroup.processes {
|
||||
if process.CurrentState() == StateReady {
|
||||
runningProcesses = append(runningProcesses, gin.H{
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"model": process.ID,
|
||||
"state": process.state,
|
||||
"cmd": process.config.Cmd,
|
||||
"proxy": process.config.Proxy,
|
||||
"ttl": process.config.UnloadAfter,
|
||||
"name": process.config.Name,
|
||||
"description": process.config.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,11 +18,13 @@ type Model struct {
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
Unlisted bool `json:"unlisted"`
|
||||
PeerID string `json:"peerID"`
|
||||
}
|
||||
|
||||
func addApiHandlers(pm *ProxyManager) {
|
||||
// Add API endpoints for React to consume
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
// Protected with API key authentication
|
||||
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
@@ -82,6 +84,18 @@ func (pm *ProxyManager) getModelStatus() []Model {
|
||||
})
|
||||
}
|
||||
|
||||
// Iterate over the peer models
|
||||
if pm.peerProxy != nil {
|
||||
for peerID, peer := range pm.peerProxy.ListPeers() {
|
||||
for _, modelID := range peer.Models {
|
||||
models = append(models, Model{
|
||||
Id: modelID,
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
|
||||
+394
-14
@@ -3,6 +3,7 @@ package proxy
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||
return r.closeChannel
|
||||
}
|
||||
|
||||
func (r *TestResponseRecorder) closeClient() {
|
||||
r.closeChannel <- true
|
||||
}
|
||||
|
||||
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||
return &TestResponseRecorder{
|
||||
httptest.NewRecorder(),
|
||||
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
model2Config.Name = " " // empty whitespace only strings will get ignored
|
||||
model2Config.Description = " "
|
||||
|
||||
config := config.Config{
|
||||
cfg := config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": model1Config,
|
||||
"model2": model2Config,
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
Peers: map[string]config.PeerConfig{
|
||||
"peer1": {
|
||||
Proxy: "http://peer1:8080",
|
||||
Models: []string{"peer-model-a", "peer-model-b"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
proxy := New(cfg)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
// Check the number of models returned (3 local + 2 peer models)
|
||||
assert.Len(t, response.Data, 5)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
"peer-model-a": {},
|
||||
"peer-model-b": {},
|
||||
}
|
||||
|
||||
// make all models
|
||||
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
description, ok := model["description"].(string)
|
||||
assert.True(t, ok, "description should be a string")
|
||||
assert.Equal(t, "Model 1 description is used for testing", description)
|
||||
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
|
||||
// Peer models should have meta.llamaswap.peerID
|
||||
meta, exists := model["meta"]
|
||||
assert.True(t, exists, "peer model should have meta field")
|
||||
metaMap, ok := meta.(map[string]interface{})
|
||||
assert.True(t, ok, "meta should be a map")
|
||||
llamaswap, exists := metaMap["llamaswap"]
|
||||
assert.True(t, exists, "meta should have llamaswap field")
|
||||
llamaswapMap, ok := llamaswap.(map[string]interface{})
|
||||
assert.True(t, ok, "llamaswap should be a map")
|
||||
peerID, exists := llamaswapMap["peerID"]
|
||||
assert.True(t, exists, "llamaswap should have peerID field")
|
||||
assert.Equal(t, "peer1", peerID)
|
||||
} else {
|
||||
_, exists := model["name"]
|
||||
assert.False(t, exists, "unexpected name field for model: %s", modelID)
|
||||
@@ -502,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxyManager_Shutdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test")
|
||||
}
|
||||
|
||||
// make broken model configurations
|
||||
model1Config := getTestSimpleResponderConfigPort("model1", 9991)
|
||||
model1Config.Proxy = "http://localhost:10001/"
|
||||
@@ -650,8 +672,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Define a helper struct to parse the JSON response.
|
||||
type RunningResponse struct {
|
||||
Running []struct {
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
State string `json:"state"`
|
||||
Cmd string `json:"cmd"`
|
||||
Proxy string `json:"proxy"`
|
||||
TTL int `json:"ttl"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"running"`
|
||||
}
|
||||
|
||||
@@ -699,6 +726,11 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
|
||||
// Is the model loaded?
|
||||
assert.Equal(t, "ready", response.Running[0].State)
|
||||
|
||||
// Verify extended fields are present
|
||||
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
||||
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
||||
assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to 0")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -944,7 +976,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
modelConfig := getTestSimpleResponderConfig("model1")
|
||||
modelConfig.Filters = config.ModelFilters{
|
||||
StripParams: "temperature, model, stream",
|
||||
Filters: config.Filters{
|
||||
StripParams: "temperature, model, stream",
|
||||
},
|
||||
}
|
||||
|
||||
config := config.AddDefaultGroupToConfig(config.Config{
|
||||
@@ -1187,3 +1221,349 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
|
||||
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("valid key in x-api-key header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer valid-key-2")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("both headers with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
req.Header.Set("Authorization", "Bearer valid-key-1")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "invalid-key")
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("missing key returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
})
|
||||
|
||||
t.Run("valid key in Basic Auth header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
// Basic Auth: base64("anyuser:valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "unauthorized")
|
||||
})
|
||||
|
||||
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
req.Header.Set("x-api-key", "valid-key-1")
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
|
||||
// Config without RequiredAPIKeys - auth should be disabled
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
t.Run("requests pass without API key when not configured", func(t *testing.T) {
|
||||
reqBody := `{"model":"model1"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
|
||||
// in proxyInferenceHandler for issue #433
|
||||
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
|
||||
t.Run("requests to peer models are proxied", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config with peers but no local model for "peer-model"
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "from-peer")
|
||||
})
|
||||
|
||||
t.Run("local models take precedence over peer models", func(t *testing.T) {
|
||||
// Create a test server to act as the peer - should NOT be called
|
||||
peerCalled := false
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
peerCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"from-peer"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
// Create config where "shared-model" exists both locally and on peer
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- shared-model
|
||||
models:
|
||||
shared-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-response
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"shared-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "local-response")
|
||||
assert.False(t, peerCalled, "peer should not be called when local model exists")
|
||||
})
|
||||
|
||||
t.Run("unknown model returns error", func(t *testing.T) {
|
||||
// Create a test server to act as the peer
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer API key is injected into request", func(t *testing.T) {
|
||||
var receivedAuthHeader string
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
apiKey: secret-peer-key
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
|
||||
})
|
||||
|
||||
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
|
||||
testConfig := config.AddDefaultGroupToConfig(config.Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]config.ModelConfig{
|
||||
"local-model": getTestSimpleResponderConfig("local-model"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// peerProxy exists but has no peer models configured
|
||||
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
|
||||
|
||||
reqBody := `{"model":"unknown-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
|
||||
})
|
||||
|
||||
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
|
||||
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("data: test\n\n"))
|
||||
}))
|
||||
defer peerServer.Close()
|
||||
|
||||
configStr := fmt.Sprintf(`
|
||||
logLevel: error
|
||||
peers:
|
||||
test-peer:
|
||||
proxy: %s
|
||||
models:
|
||||
- peer-model
|
||||
models:
|
||||
local-model:
|
||||
cmd: %s -port ${PORT} -silent -respond local-model
|
||||
`, peerServer.URL, getSimpleResponderPath())
|
||||
|
||||
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
|
||||
assert.NoError(t, err)
|
||||
|
||||
proxy := New(testConfig)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
reqBody := `{"model":"peer-model"}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := CreateTestResponseRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
|
||||
})
|
||||
}
|
||||
|
||||
Generated
+98
-39
@@ -12,7 +12,7 @@
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2"
|
||||
"react-router-dom": "^7.12.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.0",
|
||||
@@ -75,6 +75,7 @@
|
||||
"integrity": "sha512-bXYxrXFubeYdvB0NhD/NBB3Qi6aZeV20GOWVI47t2dkecCEoneR4NPVcb7abpXDEvejgrUfFtG6vG/zxAKmg+g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@ampproject/remapping": "^2.2.0",
|
||||
"@babel/code-frame": "^7.27.1",
|
||||
@@ -1593,6 +1594,66 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/wasi-threads": "1.0.2",
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
|
||||
"version": "1.0.2",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
|
||||
"version": "0.2.10",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/core": "^1.4.3",
|
||||
"@emnapi/runtime": "^1.4.3",
|
||||
"@tybys/wasm-util": "^0.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
|
||||
"version": "0.9.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
|
||||
"version": "2.8.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "0BSD",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
|
||||
"version": "4.1.8",
|
||||
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.8.tgz",
|
||||
@@ -1707,6 +1768,7 @@
|
||||
"integrity": "sha512-JeG0rEWak0N6Itr6QUx+X60uQmN+5t3j9r/OVDtWzFXKaj6kD1BwJzOksD0FF6iWxZlbE1kB0q9vtnU2ekqa1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"csstype": "^3.0.2"
|
||||
}
|
||||
@@ -1767,6 +1829,7 @@
|
||||
"integrity": "sha512-qwxv6dq682yVvgKKp2qWwLgRbscDAYktPptK4JPojCwwi3R9cwrvIxS4lvBpzmcqzR4bdn54Z0IG1uHFskW4dA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.33.1",
|
||||
"@typescript-eslint/types": "8.33.1",
|
||||
@@ -2018,6 +2081,7 @@
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -2126,6 +2190,7 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"caniuse-lite": "^1.0.30001718",
|
||||
"electron-to-chromium": "^1.5.160",
|
||||
@@ -2232,12 +2297,16 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/cookie": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-1.0.2.tgz",
|
||||
"integrity": "sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==",
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-1.1.1.tgz",
|
||||
"integrity": "sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/express"
|
||||
}
|
||||
},
|
||||
"node_modules/cross-spawn": {
|
||||
@@ -2388,6 +2457,7 @@
|
||||
"integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -3267,9 +3337,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minizlib": {
|
||||
"version": "3.0.2",
|
||||
"resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.2.tgz",
|
||||
"integrity": "sha512-oG62iEk+CYt5Xj2YqI5Xi9xWUeZhDI8jjQmC5oThVH5JGCTgIjr7ciJDzC7MBzYd//WvR1OTmP5Q38Q8ShQtVA==",
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz",
|
||||
"integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -3279,22 +3349,6 @@
|
||||
"node": ">= 18"
|
||||
}
|
||||
},
|
||||
"node_modules/mkdirp": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz",
|
||||
"integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"mkdirp": "dist/cjs/src/bin.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
}
|
||||
},
|
||||
"node_modules/ms": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz",
|
||||
@@ -3513,6 +3567,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.1.0.tgz",
|
||||
"integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -3522,6 +3577,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.1.0.tgz",
|
||||
"integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"scheduler": "^0.26.0"
|
||||
},
|
||||
@@ -3559,9 +3615,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/react-router": {
|
||||
"version": "7.6.2",
|
||||
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.6.2.tgz",
|
||||
"integrity": "sha512-U7Nv3y+bMimgWjhlT5CRdzHPu2/KVmqPwKUCChW8en5P3znxUqwlYFlbmyj8Rgp1SF6zs5X4+77kBVknkg6a0w==",
|
||||
"version": "7.12.0",
|
||||
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.12.0.tgz",
|
||||
"integrity": "sha512-kTPDYPFzDVGIIGNLS5VJykK0HfHLY5MF3b+xj0/tTyNYL1gF1qs7u67Z9jEhQk2sQ98SUaHxlG31g1JtF7IfVw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"cookie": "^1.0.1",
|
||||
@@ -3581,12 +3637,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/react-router-dom": {
|
||||
"version": "7.6.2",
|
||||
"resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-7.6.2.tgz",
|
||||
"integrity": "sha512-Q8zb6VlTbdYKK5JJBLQEN06oTUa/RAbG/oQS1auK1I0TbJOXktqm+QENEVJU6QvWynlXPRBXI3fiOQcSEA78rA==",
|
||||
"version": "7.12.0",
|
||||
"resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-7.12.0.tgz",
|
||||
"integrity": "sha512-pfO9fiBcpEfX4Tx+iTYKDtPbrSLLCbwJ5EqP+SPYQu1VYCXdy79GSj0wttR0U4cikVdlImZuEZ/9ZNCgoaxwBA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"react-router": "7.6.2"
|
||||
"react-router": "7.12.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
@@ -3705,9 +3761,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/set-cookie-parser": {
|
||||
"version": "2.7.1",
|
||||
"resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.1.tgz",
|
||||
"integrity": "sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ==",
|
||||
"version": "2.7.2",
|
||||
"resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.2.tgz",
|
||||
"integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/shebang-command": {
|
||||
@@ -3787,17 +3843,16 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tar": {
|
||||
"version": "7.4.3",
|
||||
"resolved": "https://registry.npmjs.org/tar/-/tar-7.4.3.tgz",
|
||||
"integrity": "sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==",
|
||||
"version": "7.5.6",
|
||||
"resolved": "https://registry.npmjs.org/tar/-/tar-7.5.6.tgz",
|
||||
"integrity": "sha512-xqUeu2JAIJpXyvskvU3uvQW8PAmHrtXp2KDuMJwQqW8Sqq0CaZBAQ+dKS3RBXVhU4wC5NjAdKrmh84241gO9cA==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"dependencies": {
|
||||
"@isaacs/fs-minipass": "^4.0.0",
|
||||
"chownr": "^3.0.0",
|
||||
"minipass": "^7.1.2",
|
||||
"minizlib": "^3.0.1",
|
||||
"mkdirp": "^3.0.1",
|
||||
"minizlib": "^3.1.0",
|
||||
"yallist": "^5.0.0"
|
||||
},
|
||||
"engines": {
|
||||
@@ -3852,6 +3907,7 @@
|
||||
"integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -3904,6 +3960,7 @@
|
||||
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3982,6 +4039,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
@@ -4072,6 +4130,7 @@
|
||||
"integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@
|
||||
"react-dom": "^19.1.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-resizable-panels": "^3.0.4",
|
||||
"react-router-dom": "^7.6.2"
|
||||
"react-router-dom": "^7.12.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.0",
|
||||
|
||||
@@ -10,6 +10,7 @@ export interface Model {
|
||||
name: string;
|
||||
description: string;
|
||||
unlisted: boolean;
|
||||
peerID: string;
|
||||
}
|
||||
|
||||
interface APIProviderType {
|
||||
@@ -70,7 +71,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
const [versionInfo, setVersionInfo] = useState<VersionInfo>({
|
||||
build_date: "unknown",
|
||||
commit: "unknown",
|
||||
version: "unknown"
|
||||
version: "unknown",
|
||||
});
|
||||
//const apiEventSource = useRef<EventSource | null>(null);
|
||||
|
||||
@@ -166,7 +167,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// fetch version
|
||||
// fetch version
|
||||
const fetchVersion = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/version");
|
||||
@@ -180,7 +181,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
}
|
||||
};
|
||||
|
||||
if (connectionStatus === 'connected') {
|
||||
if (connectionStatus === "connected") {
|
||||
fetchVersion();
|
||||
}
|
||||
}, [connectionStatus]);
|
||||
@@ -265,7 +266,19 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
||||
connectionStatus,
|
||||
versionInfo,
|
||||
}),
|
||||
[models, listModels, unloadAllModels, unloadSingleModel, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics, connectionStatus, versionInfo]
|
||||
[
|
||||
models,
|
||||
listModels,
|
||||
unloadAllModels,
|
||||
unloadSingleModel,
|
||||
loadModel,
|
||||
enableAPIEvents,
|
||||
proxyLogs,
|
||||
upstreamLogs,
|
||||
metrics,
|
||||
connectionStatus,
|
||||
versionInfo,
|
||||
]
|
||||
);
|
||||
|
||||
return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
|
||||
|
||||
+49
-16
@@ -44,8 +44,24 @@ function ModelsPanel() {
|
||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
|
||||
const filteredModels = useMemo(() => {
|
||||
return models.filter((model) => showUnlisted || !model.unlisted);
|
||||
const { regularModels, peerModelsByPeerId } = useMemo(() => {
|
||||
const filtered = models.filter((model) => showUnlisted || !model.unlisted);
|
||||
const peerModels = filtered.filter((m) => m.peerID);
|
||||
|
||||
// Group peer models by peerID
|
||||
const grouped = peerModels.reduce((acc, model) => {
|
||||
const peerId = model.peerID || "unknown";
|
||||
if (!acc[peerId]) {
|
||||
acc[peerId] = [];
|
||||
}
|
||||
acc[peerId].push(model);
|
||||
return acc;
|
||||
}, {} as Record<string, typeof peerModels>);
|
||||
|
||||
return {
|
||||
regularModels: filtered.filter((m) => !m.peerID),
|
||||
peerModelsByPeerId: grouped,
|
||||
};
|
||||
}, [models, showUnlisted]);
|
||||
|
||||
const handleUnloadAllModels = useCallback(async () => {
|
||||
@@ -151,7 +167,7 @@ function ModelsPanel() {
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{filteredModels.map((model) => (
|
||||
{regularModels.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td className={`${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||
<a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank">
|
||||
@@ -186,6 +202,34 @@ function ModelsPanel() {
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{Object.keys(peerModelsByPeerId).length > 0 && (
|
||||
<>
|
||||
<h3 className="mt-8 mb-2">Peer Models</h3>
|
||||
{Object.entries(peerModelsByPeerId)
|
||||
.sort(([a], [b]) => a.localeCompare(b))
|
||||
.map(([peerId, models]) => (
|
||||
<div key={peerId} className="mb-4">
|
||||
<table className="w-full">
|
||||
<thead className="sticky top-0 bg-card z-10">
|
||||
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
|
||||
<th className="font-semibold">{peerId}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{models.map((model) => (
|
||||
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
|
||||
<td className={`pl-8 ${model.unlisted ? "text-txtsecondary" : ""}`}>
|
||||
<span>{model.id}</span>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -223,11 +267,7 @@ function TokenHistogram({ data }: { data: HistogramData }) {
|
||||
|
||||
return (
|
||||
<div className="mt-2 w-full">
|
||||
<svg
|
||||
viewBox={`0 0 ${viewBoxWidth} ${height}`}
|
||||
className="w-full h-auto"
|
||||
preserveAspectRatio="xMidYMid meet"
|
||||
>
|
||||
<svg viewBox={`0 0 ${viewBoxWidth} ${height}`} className="w-full h-auto" preserveAspectRatio="xMidYMid meet">
|
||||
{/* Y-axis */}
|
||||
<line
|
||||
x1={padding.left}
|
||||
@@ -312,14 +352,7 @@ function TokenHistogram({ data }: { data: HistogramData }) {
|
||||
/>
|
||||
|
||||
{/* X-axis labels */}
|
||||
<text
|
||||
x={padding.left}
|
||||
y={height - 5}
|
||||
fontSize="10"
|
||||
fill="currentColor"
|
||||
opacity="0.6"
|
||||
textAnchor="start"
|
||||
>
|
||||
<text x={padding.left} y={height - 5} fontSize="10" fill="currentColor" opacity="0.6" textAnchor="start">
|
||||
{min.toFixed(1)}
|
||||
</text>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user