Compare commits

...

17 Commits

Author SHA1 Message Date
Benson Wong 4413881b2d proxy: actually add /v1/responses endpoint (#449)
ref: #448
2026-01-01 13:35:45 -08:00
Benson Wong 8df5e8563b proxy: add /v1/responses and /v1/audio/voices endpoints (#448)
Updates #433
Fixes #442 #226
2026-01-01 12:52:12 -08:00
Benson Wong 7931212d3e proxy: add v1/images/edits API endpoint (#447)
Updates #433
2026-01-01 12:43:06 -08:00
Benson Wong 3dc36032fb proxy: skip very slow tests in -short test mode (#446)
* proxy: skip very slow tests in -short test mode
* CLAUDE.md: update testing instructions
2025-12-31 14:08:56 -08:00
Benson Wong addb98646f proxy: add support for basic authorization (#445)
Fixes #444 where the UI with api keys did not work. The choice to use
http basic authorization is for simple, automatic browser support. No
changes to the UI were necessary. Just use an API key as the password,
no user name is required.
2025-12-31 13:42:35 -08:00
Benson Wong 37d74efc2d proxy: add /v1/images/generations (#443)
Add support for the /v1/images/generations endpoint

Updates #433
Closes #191
2025-12-30 21:04:58 -08:00
Benson Wong 22e098ac8b Add Peer Model Support (#438)
This PR allows a single llama-swap to be the central proxy for models served by other inference servers. The peer servers can be another llama-swap or any API that supports the /v1/* inference endpoint.

Updates: #433, #299
Closes: #296
2025-12-27 20:18:06 -08:00
Benson Wong 9864f9f517 .coderabbit.yaml: disable annoying features 2025-12-23 23:53:06 -08:00
Benson Wong 53b32f3601 proxy: add API key support (#436)
Add configuration support for api keys that are enforced by llama-swap. Keys are stripped before sending them to upstream servers. 

Updates: #433, #50 and #251
2025-12-23 23:39:33 -08:00
Benson Wong 565c44766d config,proxy: add new configuration logToStdout (#432)
The new logToStdout option controls what is logged to stdout. The
default has been changed to just the proxy logs, which contain swap and
http request logs.

There are four supported settings: none, proxy, upstream, both. The
"both" setting is the legacy setting where everything was spewed to
stdout.
2025-12-21 22:23:31 -08:00
Benson Wong e6a9e210ba proxy: fix path bug in /logs/stream/{model_id} (#431)
A {model_id} containing a forward slash trips up gin's path param
parsing. This updates /logs/stream to work like /upstream where the
model_id is built up in parts and searched for in the configuration.

Updates #421
2025-12-21 21:47:14 -08:00
Benson Wong d3f329f924 proxy: Improve logging performance and allow separate log streaming (#421)
Replace container/ring.Ring with a custom circularBuffer that uses a
single contiguous []byte slice. This fixes the original implementation
which created 10,240 ring elements instead of 10KB of storage.

GetHistory is now 139x faster (145μs → 1μs) and uses 117x less memory
(1.2MB → 10KB). Allocations reduced from 2 to 1 per write operation.

Create a LogMonitor per proxy.Process, replacing the usage
of a shared one. The buffer in LogMonitor is lazy allocated on the first
call to Write and freed when the Process is stopped. This reduces
unnecessary memory usage when a model is not active.

The /logs/stream/{model_id} endpoint was added to stream logs from a
specific process.
2025-12-18 21:49:25 -08:00
Benson Wong 98879b38c1 docker: add /app to $PATH (#424)
Make it so llama-server can be called directly instead of with the full
path at /app/llama-server.

Fixes #423
Ref: #233
2025-12-06 22:58:29 -08:00
Benson Wong 7b3b0f5eae move header images around [skip ci] 2025-12-02 19:40:42 -08:00
Benson Wong 021ccceef1 README: update hero image 2025-12-02 19:37:03 -08:00
Benson Wong f03871c50a Update README.md
- add supported anthropic API 
- add example for docker hot reload support
2025-12-02 19:03:01 -08:00
Ryan Steed dc00d17abe docs: add documentation for non-root container images and security considerations (#416)
* docs: add documentation for non-root container images and security considerations
* docs: move container security section to dedicated file and update README links
2025-12-02 08:52:26 -08:00
34 changed files with 2317 additions and 254 deletions
+7
View File
@@ -8,8 +8,15 @@ reviews:
poem: false poem: false
review_status: true review_status: true
collapse_walkthrough: false collapse_walkthrough: false
sequence_diagrams: false
finishing_touches:
docstrings:
enabled: false
auto_review: auto_review:
enabled: true enabled: true
drafts: false drafts: false
chat: chat:
auto_reply: true auto_reply: true
issue_enrichment:
planning:
enabled: false
+4 -2
View File
@@ -11,8 +11,10 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
## Testing ## Testing
- `make test-dev` - Use this when making iterative changes. Runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Follow test naming conventions like `TestProxyManager_<test name>`, `TestProcessGroup_<test name>`, etc.
- `make test-all` - runs at the end before completing work. Includes long running concurrency tests. - Use `go test -v -run <name pattern for new tests>` to run any new tests you've written.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests.
## Workflow Tasks ## Workflow Tasks
+32 -12
View File
@@ -1,4 +1,4 @@
![llama-swap header image](header2.png) ![llama-swap header image](docs/assets/hero3.webp)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total) ![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml) ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml)
![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap) ![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap)
@@ -13,14 +13,20 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies - ✅ Easy to deploy and configure: one binary, one configuration file. no external dependencies
- ✅ On-demand model switching - ✅ On-demand model switching
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc.)
- future proof, upgrade your inference servers at any time. - future proof, upgrade your inference servers at any time.
- ✅ OpenAI API supported endpoints: - ✅ OpenAI API supported endpoints:
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/responses`
- `v1/embeddings` - `v1/embeddings`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- `v1/audio/voices`
- `v1/images/generations`
- `v1/images/edits`
- ✅ Anthropic API supported endpoints:
- `v1/messages`
- ✅ llama-server (llama.cpp) supported endpoints - ✅ llama-server (llama.cpp) supported endpoints
- `v1/rerank`, `v1/reranking`, `/rerank` - `v1/rerank`, `v1/reranking`, `/rerank`
- `/infill` - for code infilling - `/infill` - for code infilling
@@ -32,6 +38,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring - `/log` - remote log monitoring
- `/health` - just returns "OK" - `/health` - just returns "OK"
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable - ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl` - Automatic unloading of models after timeout by setting a `ttl`
@@ -44,7 +51,6 @@ llama-swap includes a real time web interface for monitoring logs and controllin
<img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" /> <img width="1164" height="745" alt="image" src="https://github.com/user-attachments/assets/bacf3f9d-819f-430b-9ed2-1bfaa8d54579" />
The Activity Page shows recent requests: The Activity Page shows recent requests:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" /> <img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
@@ -61,7 +67,7 @@ llama-swap can be installed in multiple ways
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc). Nightly container images with llama-swap and llama-server are built for multiple platforms (cuda, vulkan, intel, etc.) including [non-root variants with improved security](docs/container-security.md).
```shell ```shell
$ docker pull ghcr.io/mostlygeek/llama-swap:cuda $ docker pull ghcr.io/mostlygeek/llama-swap:cuda
@@ -71,6 +77,14 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \ -v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \ -v /path/to/custom/config.yaml:/app/config.yaml \
ghcr.io/mostlygeek/llama-swap:cuda ghcr.io/mostlygeek/llama-swap:cuda
# configuration hot reload supported with a
# directory volume mount
$ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \
-v /path/to/config:/config \
ghcr.io/mostlygeek/llama-swap:cuda -config /config/config.yaml -watch-config
``` ```
<details> <details>
@@ -89,6 +103,9 @@ docker pull ghcr.io/mostlygeek/llama-swap:musa
# tagged llama-swap, platform and llama-server version images # tagged llama-swap, platform and llama-server version images
docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795 docker pull ghcr.io/mostlygeek/llama-swap:v166-cuda-b6795
# non-root cuda
docker pull ghcr.io/mostlygeek/llama-swap:cuda-non-root
``` ```
</details> </details>
@@ -191,23 +208,26 @@ As a safeguard, llama-swap also sets `X-Accel-Buffering: no` on SSE responses. H
## Monitoring Logs on the CLI ## Monitoring Logs on the CLI
```shell ```sh
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' $ curl http://host/logs
# streams combined logs # streams combined logs
curl -Ns 'http://host/logs/stream' curl -Ns http://host/logs/stream
# just llama-swap's logs # stream llama-swap's proxy status logs
curl -Ns 'http://host/logs/stream/proxy' curl -Ns http://host/logs/stream/proxy
# just upstream's logs # stream logs from upstream processes that llama-swap loads
curl -Ns 'http://host/logs/stream/upstream' curl -Ns http://host/logs/stream/upstream
# stream logs only from a specific model
curl -Ns http://host/logs/stream/{model_id}
# stream and filter logs with linux pipes # stream and filter logs with linux pipes
curl -Ns http://host/logs/stream | grep 'eval time' curl -Ns http://host/logs/stream | grep 'eval time'
# skips history and just streams new log entries # appending ?no-history will disable sending buffered history first
curl -Ns 'http://host/logs/stream?no-history' curl -Ns 'http://host/logs/stream?no-history'
``` ```
@@ -0,0 +1,85 @@
# Replace ring.Ring with Efficient Circular Byte Buffer
## Overview
Replace the inefficient `container/ring.Ring` implementation in `logMonitor.go` with a simple circular byte buffer that uses a single contiguous `[]byte` slice. This eliminates per-write allocations, improves cache locality, and correctly implements a 10KB buffer.
## Current Issues
1. `ring.New(10 * 1024)` creates 10,240 ring **elements**, not 10KB of storage
2. Every `Write()` call allocates a new `[]byte` slice inside the lock
3. `GetHistory()` iterates all 10,240 elements and appends repeatedly (geometric reallocs)
4. Linked list structure has poor cache locality and pointer overhead
## Design Requirements
### New CircularBuffer Type
Create a simple circular byte buffer with:
- Single pre-allocated `[]byte` of fixed capacity (10KB)
- `head` and `size` integers to track write position and data length
- No per-write allocations
### API Requirements
The new buffer must support:
1. **Write(p []byte)** - Append bytes, overwriting oldest data when full
2. **GetHistory() []byte** - Return all buffered data in correct order (oldest to newest)
### Implementation Details
```go
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
```
**Write logic:**
- If `len(p) >= capacity`: just keep the last `capacity` bytes
- Otherwise: write bytes at `head`, wrapping around if needed
- Update `head` and `size` accordingly
- Data is copied into the internal buffer (not stored by reference)
**GetHistory logic:**
- Calculate start position: `(head - size + cap) % cap`
- If not wrapped: single slice copy
- If wrapped: two copies (end of buffer + beginning)
- Returns a **new slice** (copy), not a view into internal buffer
### Immutability Guarantees (must preserve)
Per existing tests:
1. Modifying input `[]byte` after `Write()` must not affect stored data
2. `GetHistory()` returns independent copy - modifications don't affect buffer
## Files to Modify
- `proxy/logMonitor.go` - Replace `buffer *ring.Ring` with new circular buffer
## Testing Plan
Existing tests in `logMonitor_test.go` should continue to pass:
- `TestLogMonitor` - Basic write/read and subscriber notification
- `TestWrite_ImmutableBuffer` - Verify writes don't affect returned history
- `TestWrite_LogTimeFormat` - Timestamp formatting
Add new tests:
- Test buffer wrap-around behavior
- Test large writes that exceed buffer capacity
- Test exact capacity boundary conditions
## Checklist
- [ ] Create `circularBuffer` struct in `logMonitor.go`
- [ ] Implement `Write()` method for circular buffer
- [ ] Implement `GetHistory()` method for circular buffer
- [ ] Update `LogMonitor` struct to use new buffer
- [ ] Update `NewLogMonitorWriter()` to initialize new buffer
- [ ] Update `LogMonitor.Write()` to use new buffer
- [ ] Update `LogMonitor.GetHistory()` to use new buffer
- [ ] Remove `"container/ring"` import
- [ ] Run `make test-dev` to verify existing tests pass
- [ ] Add wrap-around test case
- [ ] Run `make test-all` for final validation
+52
View File
@@ -273,6 +273,58 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"description": "A dictionary of event triggers and actions. Only supported hook is on_startup." "description": "A dictionary of event triggers and actions. Only supported hook is on_startup."
},
"logToStdout": {
"type": "string",
"enum": [
"proxy",
"upstream",
"both",
"none"
],
"default": "proxy",
"description": "Controls what is logged to stdout. 'proxy': logs generated by llama-swap, 'upstream': copy of upstream process stdout logs, 'both': both interleaved together, 'none': no logs written to stdout."
},
"apiKeys": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"default": [],
"description": "Require an API key when making requests to inference endpoints. When empty, authorization will not be checked. Each key is a non-empty string."
},
"peers": {
"type": "object",
"additionalProperties": {
"type": "object",
"required": [
"proxy",
"models"
],
"properties": {
"proxy": {
"type": "string",
"format": "uri",
"description": "A valid base URL to proxy requests to. Requested path to llama-swap will be appended to the end of the proxy value."
},
"apiKey": {
"type": "string",
"default": "",
"description": "A string key to be injected into the request. If blank, no key will be added. Key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>."
},
"models": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"description": "A list of models served by the peer."
}
}
},
"default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
} }
} }
} }
+52
View File
@@ -34,6 +34,16 @@ logLevel: info
# - For more info, read: https://pkg.go.dev/time#pkg-constants # - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: "" logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory # metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000 # - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded # - controls how many metrics are stored in memory before older ones are discarded
@@ -60,6 +70,16 @@ sendLoadingState: true
# all fields except for Id so chat UIs can use the alias equivalent to the original. # all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions # macros: a dictionary of string substitutions
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - macros are reusable snippets # - macros are reusable snippets
@@ -321,3 +341,35 @@ hooks:
# otherwise models will be loaded and swapped out # otherwise models will be loaded and swapped out
preload: preload:
- "llama" - "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
+4
View File
@@ -29,6 +29,10 @@ RUN chown --recursive $UID:$GID $HOME /app
USER $UID:$GID USER $UID:$GID
WORKDIR /app WORKDIR /app
# Add /app to PATH
ENV PATH="/app:${PATH}"
RUN \ RUN \
curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \ curl -LO "https://github.com/${LS_REPO}/releases/download/v${LS_VER}/llama-swap_${LS_VER}_linux_amd64.tar.gz" && \
tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \ tar -zxf "llama-swap_${LS_VER}_linux_amd64.tar.gz" && \

Before

Width:  |  Height:  |  Size: 261 KiB

After

Width:  |  Height:  |  Size: 261 KiB

Before

Width:  |  Height:  |  Size: 351 KiB

After

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

+82 -1
View File
@@ -86,9 +86,12 @@ llama-swap supports many more features to customize how you want to manage your
## Full Configuration Example ## Full Configuration Example
> [!NOTE] > [!NOTE]
> This is a copy of `config.example.yaml`. Always check that for the most up to date examples. > Always check [config.example.yaml](https://github.com/mostlygeek/llama-swap/blob/main/config.example.yaml) for the most up to date reference for all example configurations.
```yaml ```yaml
# add this modeline for validation in vscode
# yaml-language-server: $schema=https://raw.githubusercontent.com/mostlygeek/llama-swap/refs/heads/main/config-schema.json
#
# llama-swap YAML configuration example # llama-swap YAML configuration example
# ------------------------------------- # -------------------------------------
# #
@@ -114,6 +117,24 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error # - Valid log levels: debug, info, warn, error
logLevel: info logLevel: info
# logTimeFormat: enables and sets the logging timestamp format
# - optional, default (disabled): ""
# - Valid values: "", "ansic", "unixdate", "rubydate", "rfc822", "rfc822z",
# "rfc850", "rfc1123", "rfc1123z", "rfc3339", "rfc3339nano", "kitchen",
# "stamp", "stampmilli", "stampmicro", and "stampnano".
# - For more info, read: https://pkg.go.dev/time#pkg-constants
logTimeFormat: ""
# logToStdout: controls what is logged to stdout
# - optional, default: "proxy"
# - valid values:
# - "proxy": logs generated by llama-swap when swapping models,
# handling requests, etc.
# - "upstream": a copy of an upstream processes stdout logs
# - "both": both the proxy and upstream logs interleaved together
# - "none": no logs are ever written to stdout
logToStdout: "proxy"
# metricsMaxInMemory: maximum number of metrics to keep in memory # metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000 # - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded # - controls how many metrics are stored in memory before older ones are discarded
@@ -126,6 +147,30 @@ metricsMaxInMemory: 1000
# - it is automatically incremented for every model that uses it # - it is automatically incremented for every model that uses it
startPort: 10001 startPort: 10001
# sendLoadingState: inject loading status updates into the reasoning (thinking)
# field
# - optional, default: false
# - when true, a stream of loading messages will be sent to the client in the
# reasoning field so chat UIs can show that loading is in progress.
# - see #366 for more details
sendLoadingState: true
# includeAliasesInList: present aliases within the /v1/models OpenAI API listing
# - optional, default: false
# - when true, model aliases will be output to the API model listing duplicating
# all fields except for Id so chat UIs can use the alias equivalent to the original.
includeAliasesInList: false
# apiKeys: require an API key when making requests to inference endpoints
# - optional, default: []
# - when empty (the default) authorization will not be checked as llama-swap is default-allow
# - each key is a non-empty string
apiKeys:
- "sk-hunter2"
# hint, one liner: printf "sk-%s\n" "$(head -c 48 /dev/urandom | base64 )"
- "sk-gyCPiKUcIfPlaM4OSMZekkprgijPx6+OsmQs8Rsg0xZ9qpy6gKWsIKqHOk+cgXVx"
- "sk-+QtIn0Zjj4UHjiaZYiZEnru4mrwKM9RzhmJeK5SobNXLl8QMFXxGz1/2lEuvQpkb"
# macros: a dictionary of string substitutions # macros: a dictionary of string substitutions
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - macros are reusable snippets # - macros are reusable snippets
@@ -274,6 +319,10 @@ models:
# - recommended to be omitted and the default used # - recommended to be omitted and the default used
concurrencyLimit: 0 concurrencyLimit: 0
# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
# Unlisted model example: # Unlisted model example:
"qwen-unlisted": "qwen-unlisted":
# unlisted: boolean, true or false # unlisted: boolean, true or false
@@ -383,4 +432,36 @@ hooks:
# otherwise models will be loaded and swapped out # otherwise models will be loaded and swapped out
preload: preload:
- "llama" - "llama"
# peers: a dictionary of remote peers and models they provide
# - optional, default empty dictionary
# - peers can be another llama-swap
# - peers can be any server that provides the /v1/ generative api endpoints supported by llama-swap
peers:
# keys is the peer'd ID
llama-swap-peer:
# proxy: a valid base URL to proxy requests to
# - required
# - requested path to llama-swap will be appended to the end of the proxy value
proxy: http://192.168.1.23
# models: a list of models served by the peer
# - required
models:
- model_a
- model_b
- embeddings/model_c
openrouter:
proxy: https://openrouter.ai/api
# apiKey: a string key to be injected into the request
# - optional, default: ""
# - if blank, no key will be added to the request
# - key will be injected into headers: Authorization: Bearer <key> and x-api-key: <key>
apiKey: sk-your-openrouter-key
models:
- meta-llama/llama-3.1-8b-instruct
- qwen/qwen3-235b-a22b-2507
- deepseek/deepseek-v3.2
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
``` ```
+9
View File
@@ -0,0 +1,9 @@
## Container Security
For convenience, the default container images use the **root** user within the container. This permits simplified access to host resources including volume mounts and hardware devices under `/dev/dri` (_for Vulkan support_). But this can widen the attack surface to privilege escalation exploits.
Alternative images, tagged as `non-root`, are also available. For example, `llama-swap:cpu-non-root` uses the unprivileged **app** user by default. Depending on deployment requirements, additional configuration may be necessary to ensure that the container retains access to required hosts resources. This might entail customizing host filesystem permissions/ownership appropriately or injecting host group membership into the container.
Docker offers a [system-wide option enabling user namespace remapping](https://docs.docker.com/engine/security/userns-remap/) to accommodate situations were a **root** container user is required but also mentions that _"The best way to prevent privilege-escalation attacks from within a container is to configure your container's applications to run as unprivileged users."_ Podman offers similar capability, per-container, to [set UID/GID mapping in a new user namespace](https://docs.podman.io/en/latest/markdown/podman-run.1.html#set-uid-gid-mapping-in-a-new-user-namespace).
The Large Language Model (_LLM/AI_) ecosystem is rapidly evolving and [serious security vulnerabilities have surfaced in the past](https://huggingface.co/docs/hub/security-pickle). These alternative _non-root_ images could reduce the impact of future unknown problems. However, proper planning and configuration is recommended to utilize them.
+31
View File
@@ -15,6 +15,12 @@ import (
) )
const DEFAULT_GROUP_ID = "(default)" const DEFAULT_GROUP_ID = "(default)"
const (
LogToStdoutProxy = "proxy"
LogToStdoutUpstream = "upstream"
LogToStdoutBoth = "both"
LogToStdoutNone = "none"
)
type MacroEntry struct { type MacroEntry struct {
Name string Name string
@@ -114,6 +120,7 @@ type Config struct {
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"` LogLevel string `yaml:"logLevel"`
LogTimeFormat string `yaml:"logTimeFormat"` LogTimeFormat string `yaml:"logTimeFormat"`
LogToStdout string `yaml:"logToStdout"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"` MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
@@ -136,6 +143,12 @@ type Config struct {
// present aliases to /v1/models OpenAI API listing // present aliases to /v1/models OpenAI API listing
IncludeAliasesInList bool `yaml:"includeAliasesInList"` IncludeAliasesInList bool `yaml:"includeAliasesInList"`
// support API keys, see issue #433, #50, #251
RequiredAPIKeys []string `yaml:"apiKeys"`
// support remote peers, see issue #433, #296
Peers PeerDictionaryConfig `yaml:"peers"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -177,6 +190,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
StartPort: 5800, StartPort: 5800,
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000, MetricsMaxInMemory: 1000,
} }
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
@@ -193,6 +207,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
return Config{}, fmt.Errorf("startPort must be greater than 1") return Config{}, fmt.Errorf("startPort must be greater than 1")
} }
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map // Populate the aliases map
config.aliases = make(map[string]string) config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models { for modelName, modelConfig := range config.Models {
@@ -404,6 +424,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.Hooks.OnStartup.Preload = toPreload config.Hooks.OnStartup.Preload = toPreload
} }
// check api keys validatity
for _, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
}
}
return config, nil return config, nil
} }
+1
View File
@@ -166,6 +166,7 @@ groups:
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800, StartPort: 5800,
Macros: MacroList{ Macros: MacroList{
{"svr-path", "path/to/server"}, {"svr-path", "path/to/server"},
+48
View File
@@ -761,3 +761,51 @@ models:
}) })
} }
} }
func TestConfig_APIKeys_Invalid(t *testing.T) {
tests := []struct {
name string
content string
expectedErr string
}{
{
name: "empty string",
content: `apiKeys: [""]`,
expectedErr: "empty api key found in apiKeys",
},
{
name: "blank spaces only",
content: `apiKeys: [" "]`,
expectedErr: "api key cannot contain spaces: ` `",
},
{
name: "contains leading space",
content: `apiKeys: [" key123"]`,
expectedErr: "api key cannot contain spaces: ` key123`",
},
{
name: "contains trailing space",
content: `apiKeys: ["key123 "]`,
expectedErr: "api key cannot contain spaces: `key123 `",
},
{
name: "contains middle space",
content: `apiKeys: ["key 123"]`,
expectedErr: "api key cannot contain spaces: `key 123`",
},
{
name: "empty in list with valid keys",
content: `apiKeys: ["valid-key", "", "another-key"]`,
expectedErr: "empty api key found in apiKeys",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content))
if assert.Error(t, err) {
assert.Equal(t, tt.expectedErr, err.Error())
}
})
}
}
+1
View File
@@ -158,6 +158,7 @@ groups:
expected := Config{ expected := Config{
LogLevel: "info", LogLevel: "info",
LogTimeFormat: "", LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
StartPort: 5800, StartPort: 5800,
Macros: MacroList{ Macros: MacroList{
{"svr-path", "path/to/server"}, {"svr-path", "path/to/server"},
+47
View File
@@ -0,0 +1,47 @@
package config
import (
"fmt"
"net/url"
)
type PeerDictionaryConfig map[string]PeerConfig
type PeerConfig struct {
Proxy string `yaml:"proxy"`
ProxyURL *url.URL `yaml:"-"`
ApiKey string `yaml:"apiKey"`
Models []string `yaml:"models"`
}
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPeerConfig PeerConfig
defaults := rawPeerConfig{
Proxy: "",
ApiKey: "",
Models: []string{},
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Validate proxy is not empty
if defaults.Proxy == "" {
return fmt.Errorf("proxy is required")
}
// Validate proxy is a valid URL and store the parsed value
parsedURL, err := url.Parse(defaults.Proxy)
if err != nil {
return fmt.Errorf("invalid peer proxy URL (%s): %w", defaults.Proxy, err)
}
defaults.ProxyURL = parsedURL
// Validate models is not empty
if len(defaults.Models) == 0 {
return fmt.Errorf("peer models can not be empty")
}
*c = PeerConfig(defaults)
return nil
}
+139
View File
@@ -0,0 +1,139 @@
package config
import (
"testing"
"gopkg.in/yaml.v3"
)
func TestPeerConfig_UnmarshalYAML(t *testing.T) {
tests := []struct {
name string
yaml string
wantErr string
}{
{
name: "valid config",
yaml: `
proxy: http://192.168.1.23
models:
- model_a
- model_b
`,
wantErr: "",
},
{
name: "valid config with apiKey",
yaml: `
proxy: https://openrouter.ai/api
apiKey: sk-test-key
models:
- meta-llama/llama-3.1-8b-instruct
`,
wantErr: "",
},
{
name: "missing proxy",
yaml: `
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "empty proxy",
yaml: `
proxy: ""
models:
- model_a
`,
wantErr: "proxy is required",
},
{
name: "invalid proxy URL",
yaml: `
proxy: "://invalid"
models:
- model_a
`,
wantErr: "invalid peer proxy URL",
},
{
name: "missing models",
yaml: `
proxy: http://localhost:8080
`,
wantErr: "peer models can not be empty",
},
{
name: "empty models",
yaml: `
proxy: http://localhost:8080
models: []
`,
wantErr: "peer models can not be empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var config PeerConfig
err := yaml.Unmarshal([]byte(tt.yaml), &config)
if tt.wantErr == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.wantErr)
} else if !contains(err.Error(), tt.wantErr) {
t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error())
}
}
})
}
}
func TestPeerConfig_ProxyURL(t *testing.T) {
yamlData := `
proxy: http://192.168.1.23:8080/api
apiKey: sk-test
models:
- model_a
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.ProxyURL == nil {
t.Fatal("ProxyURL should not be nil")
}
if config.ProxyURL.Host != "192.168.1.23:8080" {
t.Errorf("expected host %q, got %q", "192.168.1.23:8080", config.ProxyURL.Host)
}
if config.ProxyURL.Scheme != "http" {
t.Errorf("expected scheme %q, got %q", "http", config.ProxyURL.Scheme)
}
if config.ProxyURL.Path != "/api" {
t.Errorf("expected path %q, got %q", "/api", config.ProxyURL.Path)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchSubstring(s, substr)
}
func searchSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+101 -16
View File
@@ -1,7 +1,6 @@
package proxy package proxy
import ( import (
"container/ring"
"context" "context"
"fmt" "fmt"
"io" "io"
@@ -12,6 +11,85 @@ import (
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
) )
// circularBuffer is a fixed-size circular byte buffer that overwrites
// oldest data when full. It provides O(1) writes and O(n) reads.
type circularBuffer struct {
data []byte // pre-allocated capacity
head int // next write position
size int // current number of bytes stored (0 to cap)
}
func newCircularBuffer(capacity int) *circularBuffer {
return &circularBuffer{
data: make([]byte, capacity),
head: 0,
size: 0,
}
}
// Write appends bytes to the buffer, overwriting oldest data when full.
// Data is copied into the internal buffer (not stored by reference).
func (cb *circularBuffer) Write(p []byte) {
if len(p) == 0 {
return
}
cap := len(cb.data)
// If input is larger than capacity, only keep the last cap bytes
if len(p) >= cap {
copy(cb.data, p[len(p)-cap:])
cb.head = 0
cb.size = cap
return
}
// Calculate how much space is available from head to end of buffer
firstPart := cap - cb.head
if firstPart >= len(p) {
// All data fits without wrapping
copy(cb.data[cb.head:], p)
cb.head = (cb.head + len(p)) % cap
} else {
// Data wraps around
copy(cb.data[cb.head:], p[:firstPart])
copy(cb.data[:len(p)-firstPart], p[firstPart:])
cb.head = len(p) - firstPart
}
// Update size
cb.size += len(p)
if cb.size > cap {
cb.size = cap
}
}
// GetHistory returns all buffered data in correct order (oldest to newest).
// Returns a new slice (copy), not a view into internal buffer.
func (cb *circularBuffer) GetHistory() []byte {
if cb.size == 0 {
return nil
}
result := make([]byte, cb.size)
cap := len(cb.data)
// Calculate start position (oldest data)
start := (cb.head - cb.size + cap) % cap
if start+cb.size <= cap {
// Data is contiguous, single copy
copy(result, cb.data[start:start+cb.size])
} else {
// Data wraps around, two copies
firstPart := cap - start
copy(result[:firstPart], cb.data[start:])
copy(result[firstPart:], cb.data[:cb.size-firstPart])
}
return result
}
type LogLevel int type LogLevel int
const ( const (
@@ -19,12 +97,14 @@ const (
LevelInfo LevelInfo
LevelWarn LevelWarn
LevelError LevelError
LogBufferSize = 100 * 1024
) )
type LogMonitor struct { type LogMonitor struct {
eventbus *event.Dispatcher eventbus *event.Dispatcher
mu sync.RWMutex mu sync.RWMutex
buffer *ring.Ring buffer *circularBuffer
bufferMu sync.RWMutex bufferMu sync.RWMutex
// typically this can be os.Stdout // typically this can be os.Stdout
@@ -45,7 +125,7 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor { func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{ return &LogMonitor{
eventbus: event.NewDispatcherConfig(1000), eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: nil, // lazy initialized on first Write
stdout: stdout, stdout: stdout,
level: LevelInfo, level: LevelInfo,
prefix: "", prefix: "",
@@ -64,12 +144,15 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
} }
w.bufferMu.Lock() w.bufferMu.Lock()
bufferCopy := make([]byte, len(p)) if w.buffer == nil {
copy(bufferCopy, p) w.buffer = newCircularBuffer(LogBufferSize)
w.buffer.Value = bufferCopy }
w.buffer = w.buffer.Next() w.buffer.Write(p)
w.bufferMu.Unlock() w.bufferMu.Unlock()
// Make a copy for broadcast to preserve immutability
bufferCopy := make([]byte, len(p))
copy(bufferCopy, p)
w.broadcast(bufferCopy) w.broadcast(bufferCopy)
return n, nil return n, nil
} }
@@ -77,16 +160,18 @@ func (w *LogMonitor) Write(p []byte) (n int, err error) {
func (w *LogMonitor) GetHistory() []byte { func (w *LogMonitor) GetHistory() []byte {
w.bufferMu.RLock() w.bufferMu.RLock()
defer w.bufferMu.RUnlock() defer w.bufferMu.RUnlock()
if w.buffer == nil {
return nil
}
return w.buffer.GetHistory()
}
var history []byte // Clear releases the buffer memory, making it eligible for GC.
w.buffer.Do(func(p any) { // The buffer will be lazily re-allocated on the next Write.
if p != nil { func (w *LogMonitor) Clear() {
if content, ok := p.([]byte); ok { w.bufferMu.Lock()
history = append(history, content...) w.buffer = nil
} w.bufferMu.Unlock()
}
})
return history
} }
func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc { func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
+201
View File
@@ -113,3 +113,204 @@ func TestWrite_LogTimeFormat(t *testing.T) {
t.Fatalf("Cannot find timestamp: %v", err) t.Fatalf("Cannot find timestamp: %v", err)
} }
} }
func TestCircularBuffer_WrapAround(t *testing.T) {
// Create a small buffer to test wrap-around
cb := newCircularBuffer(10)
// Write "hello" (5 bytes)
cb.Write([]byte("hello"))
if got := string(cb.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Write "world" (5 bytes) - buffer now full
cb.Write([]byte("world"))
if got := string(cb.GetHistory()); got != "helloworld" {
t.Errorf("Expected 'helloworld', got %q", got)
}
// Write "12345" (5 bytes) - should overwrite "hello"
cb.Write([]byte("12345"))
if got := string(cb.GetHistory()); got != "world12345" {
t.Errorf("Expected 'world12345', got %q", got)
}
// Write data larger than buffer capacity
cb.Write([]byte("abcdefghijklmnop")) // 16 bytes, only last 10 kept
if got := string(cb.GetHistory()); got != "ghijklmnop" {
t.Errorf("Expected 'ghijklmnop', got %q", got)
}
}
func TestCircularBuffer_BoundaryConditions(t *testing.T) {
// Test empty buffer
cb := newCircularBuffer(10)
if got := cb.GetHistory(); got != nil {
t.Errorf("Expected nil for empty buffer, got %q", got)
}
// Test exact capacity
cb.Write([]byte("1234567890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
// Test write exactly at capacity boundary
cb = newCircularBuffer(10)
cb.Write([]byte("12345"))
cb.Write([]byte("67890"))
if got := string(cb.GetHistory()); got != "1234567890" {
t.Errorf("Expected '1234567890', got %q", got)
}
}
func TestLogMonitor_LazyInit(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Buffer should be nil before any writes
if lm.buffer != nil {
t.Error("Expected buffer to be nil before first write")
}
// GetHistory should return nil when buffer is nil
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history before first write, got %q", got)
}
// Write should lazily initialize the buffer
lm.Write([]byte("test"))
if lm.buffer == nil {
t.Error("Expected buffer to be initialized after write")
}
if got := string(lm.GetHistory()); got != "test" {
t.Errorf("Expected 'test', got %q", got)
}
}
func TestLogMonitor_Clear(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write some data
lm.Write([]byte("hello"))
if got := string(lm.GetHistory()); got != "hello" {
t.Errorf("Expected 'hello', got %q", got)
}
// Clear should release the buffer
lm.Clear()
if lm.buffer != nil {
t.Error("Expected buffer to be nil after Clear")
}
if got := lm.GetHistory(); got != nil {
t.Errorf("Expected nil history after Clear, got %q", got)
}
}
func TestLogMonitor_ClearAndReuse(t *testing.T) {
lm := NewLogMonitorWriter(io.Discard)
// Write, clear, then write again
lm.Write([]byte("first"))
lm.Clear()
lm.Write([]byte("second"))
if got := string(lm.GetHistory()); got != "second" {
t.Errorf("Expected 'second' after clear and reuse, got %q", got)
}
}
func BenchmarkLogMonitorWrite(b *testing.B) {
// Test data of varying sizes
smallMsg := []byte("small message\n")
mediumMsg := []byte(strings.Repeat("medium message content ", 10) + "\n")
largeMsg := []byte(strings.Repeat("large message content for benchmarking ", 100) + "\n")
b.Run("SmallWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(smallMsg)
}
})
b.Run("MediumWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("LargeWrite", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(largeMsg)
}
})
b.Run("WithSubscribers", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Add some subscribers
for i := 0; i < 5; i++ {
lm.OnLogData(func(data []byte) {})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.Write(mediumMsg)
}
})
b.Run("GetHistory", func(b *testing.B) {
lm := NewLogMonitorWriter(io.Discard)
// Pre-populate with data
for i := 0; i < 1000; i++ {
lm.Write(mediumMsg)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lm.GetHistory()
}
})
}
/*
Benchmark Results - MBP M1 Pro
Before (ring.Ring):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 43 ns | 40 B | 2 |
| MediumWrite (241B) | 76 ns | 264 B | 2 |
| LargeWrite (4KB) | 504 ns | 4,120 B | 2 |
| WithSubscribers (5 subs) | 355 ns | 264 B | 2 |
| GetHistory (after 1000 writes) | 145,000 ns | 1.2 MB | 22 |
After (circularBuffer 10KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 67 ns | 240 B | 1 |
| LargeWrite (4KB) | 774 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 325 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 1,042 ns | 10,240 B | 1 |
After (circularBuffer 100KB):
| Benchmark | ns/op | bytes/op | allocs/op |
|---------------------------------|------------|-----------|-----------|
| SmallWrite (14B) | 26 ns | 16 B | 1 |
| MediumWrite (241B) | 66 ns | 240 B | 1 |
| LargeWrite (4KB) | 753 ns | 4,096 B | 1 |
| WithSubscribers (5 subs) | 309 ns | 240 B | 1 |
| GetHistory (after 1000 writes) | 7,788 ns | 106,496 B | 1 |
Summary:
- GetHistory: 139x faster (10KB), 18x faster (100KB)
- Allocations: reduced from 2 to 1 across all operations
- Small/medium writes: ~1.1-1.6x faster
*/
+78 -10
View File
@@ -2,6 +2,8 @@ package proxy
import ( import (
"bytes" "bytes"
"compress/flate"
"compress/gzip"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -96,6 +98,12 @@ func (mp *metricsMonitor) wrapHandler(
next func(modelID string, w http.ResponseWriter, r *http.Request) error, next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error { ) error {
recorder := newBodyCopier(writer) recorder := newBodyCopier(writer)
// Filter Accept-Encoding to only include encodings we can decompress for metrics
if ae := request.Header.Get("Accept-Encoding"); ae != "" {
request.Header.Set("Accept-Encoding", filterAcceptEncoding(ae))
}
if err := next(modelID, recorder, request); err != nil { if err := next(modelID, recorder, request); err != nil {
return err return err
} }
@@ -108,17 +116,36 @@ func (mp *metricsMonitor) wrapHandler(
return nil return nil
} }
// Initialize default metrics - these will always be recorded
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
}
body := recorder.body.Bytes() body := recorder.body.Bytes()
if len(body) == 0 { if len(body) == 0 {
mp.logger.Warn("metrics skipped, empty body") mp.logger.Warn("metrics: empty body, recording minimal metrics")
mp.addMetrics(tm)
return nil return nil
} }
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { // Decompress if needed
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { if encoding := recorder.Header().Get("Content-Encoding"); encoding != "" {
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path) var err error
} else { body, err = decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, request.URL.Path)
mp.addMetrics(tm) mp.addMetrics(tm)
return nil
}
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
} }
} else { } else {
if gjson.ValidBytes(body) { if gjson.ValidBytes(body) {
@@ -127,18 +154,18 @@ func (mp *metricsMonitor) wrapHandler(
timings := parsed.Get("timings") timings := parsed.Get("timings")
if usage.Exists() || timings.Exists() { if usage.Exists() || timings.Exists() {
if tm, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil { if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path) mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else { } else {
mp.addMetrics(tm) tm = parsedMetrics
} }
} }
} else { } else {
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path) mp.logger.Warnf("metrics: invalid JSON in response body path=%s, recording minimal metrics", request.URL.Path)
} }
} }
mp.addMetrics(tm)
return nil return nil
} }
@@ -251,6 +278,25 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
}, nil }, nil
} }
// decompressBody decompresses the body based on Content-Encoding header
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch strings.ToLower(strings.TrimSpace(encoding)) {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "deflate":
reader := flate.NewReader(bytes.NewReader(body))
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil // Return as-is for unknown/no encoding
}
}
// responseBodyCopier records the response body and writes to the original response writer // responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing // while also capturing it in a buffer for later processing
type responseBodyCopier struct { type responseBodyCopier struct {
@@ -289,3 +335,25 @@ func (w *responseBodyCopier) Header() http.Header {
func (w *responseBodyCopier) StartTime() time.Time { func (w *responseBodyCopier) StartTime() time.Time {
return w.start return w.start
} }
// filterAcceptEncoding filters the Accept-Encoding header to only include
// encodings we can decompress (gzip, deflate). This respects the client's
// preferences while ensuring we can parse response bodies for metrics.
func filterAcceptEncoding(acceptEncoding string) string {
if acceptEncoding == "" {
return ""
}
supported := map[string]bool{"gzip": true, "deflate": true}
var filtered []string
for _, part := range strings.Split(acceptEncoding, ",") {
// Parse encoding and optional quality value (e.g., "gzip;q=1.0")
encoding := strings.TrimSpace(strings.Split(part, ";")[0])
if supported[strings.ToLower(encoding)] {
filtered = append(filtered, strings.TrimSpace(part))
}
}
return strings.Join(filtered, ", ")
}
+154 -13
View File
@@ -1,6 +1,9 @@
package proxy package proxy
import ( import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -291,7 +294,7 @@ data: [DONE]
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 0, len(metrics))
}) })
t.Run("empty response body does not record metrics", func(t *testing.T) { t.Run("empty response body records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10) mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -307,10 +310,13 @@ data: [DONE]
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
}) })
t.Run("invalid JSON does not record metrics", func(t *testing.T) { t.Run("invalid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10) mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
@@ -328,7 +334,10 @@ data: [DONE]
assert.NoError(t, err) // Errors after response is sent are logged, not returned assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
}) })
t.Run("next handler error is propagated", func(t *testing.T) { t.Run("next handler error is propagated", func(t *testing.T) {
@@ -350,7 +359,7 @@ data: [DONE]
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 0, len(metrics))
}) })
t.Run("response without usage or timings does not record metrics", func(t *testing.T) { t.Run("response without usage or timings records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10) mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"result": "ok"}` responseBody := `{"result": "ok"}`
@@ -367,10 +376,13 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
}) })
} }
@@ -598,7 +610,7 @@ data: [DONE]
assert.Equal(t, 50, metrics[0].OutputTokens) assert.Equal(t, 50, metrics[0].OutputTokens)
}) })
t.Run("handles streaming with no valid JSON", func(t *testing.T) { t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10) mm := newMetricsMonitor(testLogger, 10)
responseBody := `data: not json responseBody := `data: not json
@@ -619,13 +631,16 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
}) })
t.Run("handles empty streaming response", func(t *testing.T) { t.Run("handles empty streaming response records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10) mm := newMetricsMonitor(testLogger, 10)
responseBody := `` responseBody := ``
@@ -642,11 +657,13 @@ data: [DONE]
ginCtx, _ := gin.CreateTestContext(rec) ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
// Empty body should not trigger WrapHandler processing
assert.NoError(t, err) assert.NoError(t, err)
metrics := mm.getMetrics() metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics)) assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
}) })
} }
@@ -691,3 +708,127 @@ func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
mm.addMetrics(metric) mm.addMetrics(metric)
} }
} }
func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
t.Run("gzip encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 100, "completion_tokens": 50}}`
// Compress with gzip
var buf bytes.Buffer
gzWriter := gzip.NewWriter(&buf)
gzWriter.Write([]byte(responseBody))
gzWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("deflate encoded response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 200, "completion_tokens": 75}}`
// Compress with deflate
var buf bytes.Buffer
flateWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression)
flateWriter.Write([]byte(responseBody))
flateWriter.Close()
compressedBody := buf.Bytes()
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "deflate")
w.WriteHeader(http.StatusOK)
w.Write(compressedBody)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 200, metrics[0].InputTokens)
assert.Equal(t, 75, metrics[0].OutputTokens)
})
t.Run("invalid gzip data records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Invalid compressed data
invalidData := []byte("this is not gzip data")
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
w.Write(invalidData)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Should not return error, just log warning
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 0, metrics[0].InputTokens)
assert.Equal(t, 0, metrics[0].OutputTokens)
})
t.Run("unknown encoding treated as uncompressed", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"usage": {"prompt_tokens": 300, "completion_tokens": 100}}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "unknown-encoding")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 300, metrics[0].InputTokens)
assert.Equal(t, 100, metrics[0].OutputTokens)
})
}
+127
View File
@@ -0,0 +1,127 @@
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"time"
"github.com/mostlygeek/llama-swap/proxy/config"
)
type peerProxyMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type PeerProxy struct {
peers config.PeerDictionaryConfig
proxyMap map[string]*peerProxyMember
}
func NewPeerProxy(peers config.PeerDictionaryConfig, proxyLogger *LogMonitor) (*PeerProxy, error) {
proxyMap := make(map[string]*peerProxyMember)
// Sort peer IDs for consistent iteration order
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
// Create a shared transport with reasonable timeouts for peer connections
// these can be tuned with feedback later
peerTransport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second, // Connection timeout
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 60 * time.Second, // Time to wait for response headers
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}
for _, peerID := range peerIDs {
peer := peers[peerID]
// Create reverse proxy for this peer
reverseProxy := httputil.NewSingleHostReverseProxy(peer.ProxyURL)
reverseProxy.Transport = peerTransport
// Wrap Director to set Host header for remote hosts (not localhost)
originalDirector := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
originalDirector(req)
// Ensure Host header matches target URL for remote proxying
req.Host = req.URL.Host
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
proxyLogger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerProxyMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
// Map each model to this peer's proxy
for _, modelID := range peer.Models {
if _, found := proxyMap[modelID]; found {
proxyLogger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
proxyMap[modelID] = pp
}
}
return &PeerProxy{
peers: peers,
proxyMap: proxyMap,
}, nil
}
func (p *PeerProxy) HasPeerModel(modelID string) bool {
_, found := p.proxyMap[modelID]
return found
}
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
return p.peers
}
func (p *PeerProxy) ProxyRequest(model_id string, writer http.ResponseWriter, request *http.Request) error {
pp, found := p.proxyMap[model_id]
if !found {
return fmt.Errorf("no peer proxy found for model %s", model_id)
}
// Inject API key if configured for this peer
if pp.apiKey != "" {
request.Header.Set("Authorization", "Bearer "+pp.apiKey)
request.Header.Set("x-api-key", pp.apiKey)
}
pp.reverseProxy.ServeHTTP(writer, request)
return nil
}
+268
View File
@@ -0,0 +1,268 @@
package proxy
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPeerProxy_EmptyPeers(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.NotNil(t, pm)
assert.Empty(t, pm.proxyMap)
}
func TestNewPeerProxy_SinglePeer(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
ApiKey: "test-key",
Models: []string{"model-a", "model-b"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 2)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.False(t, pm.HasPeerModel("model-c"))
}
func TestNewPeerProxy_MultiplePeers(t *testing.T) {
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"model-a", "model-b"},
},
"peer2": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"model-c", "model-d"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.Len(t, pm.proxyMap, 4)
assert.True(t, pm.HasPeerModel("model-a"))
assert.True(t, pm.HasPeerModel("model-b"))
assert.True(t, pm.HasPeerModel("model-c"))
assert.True(t, pm.HasPeerModel("model-d"))
}
func TestNewPeerProxy_DuplicateModelWarning(t *testing.T) {
// When the same model is in multiple peers, only the first (lexicographically by peer ID)
// should be mapped, and a warning should be logged
proxyURL1, _ := url.Parse("http://peer1.example.com:8080")
proxyURL2, _ := url.Parse("http://peer2.example.com:8080")
peers := config.PeerDictionaryConfig{
"alpha-peer": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL1,
Models: []string{"duplicate-model"},
},
"beta-peer": config.PeerConfig{
Proxy: "http://peer2.example.com:8080",
ProxyURL: proxyURL2,
Models: []string{"duplicate-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
// Should only have one entry for the duplicate model
assert.Len(t, pm.proxyMap, 1)
assert.True(t, pm.HasPeerModel("duplicate-model"))
}
func TestHasPeerModel(t *testing.T) {
proxyURL, _ := url.Parse("http://peer1.example.com:8080")
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: "http://peer1.example.com:8080",
ProxyURL: proxyURL,
Models: []string{"existing-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
assert.True(t, pm.HasPeerModel("existing-model"))
assert.False(t, pm.HasPeerModel("non-existing-model"))
}
func TestProxyRequest_ModelNotFound(t *testing.T) {
peers := config.PeerDictionaryConfig{}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("non-existing-model", w, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer proxy found for model non-existing-model")
}
func TestProxyRequest_Success(t *testing.T) {
// Create a test server to act as the peer
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("response from peer"))
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "response from peer", w.Body.String())
}
func TestProxyRequest_ApiKeyInjection(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "secret-api-key",
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Equal(t, "Bearer secret-api-key", receivedAuthHeader)
}
func TestProxyRequest_NoApiKey(t *testing.T) {
// Create a test server that checks for the Authorization header
var receivedAuthHeader string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
ApiKey: "", // No API key
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
assert.Empty(t, receivedAuthHeader)
}
func TestProxyRequest_HostHeaderSet(t *testing.T) {
// Create a test server that checks the Host header
var receivedHost string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHost = r.Host
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The Host header should be set to the target URL's host
assert.True(t, strings.HasPrefix(receivedHost, "127.0.0.1:"))
}
func TestProxyRequest_SSEHeaderModification(t *testing.T) {
// Create a test server that returns SSE content type
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
proxyURL, _ := url.Parse(testServer.URL)
peers := config.PeerDictionaryConfig{
"peer1": config.PeerConfig{
Proxy: testServer.URL,
ProxyURL: proxyURL,
Models: []string{"test-model"},
},
}
pm, err := NewPeerProxy(peers, testLogger)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
w := httptest.NewRecorder()
err = pm.ProxyRequest("test-model", w, req)
assert.NoError(t, err)
// The X-Accel-Buffering header should be set to "no" for SSE
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
}
+8 -1
View File
@@ -414,6 +414,9 @@ func (p *Process) stopCommand() {
stopStartTime := time.Now() stopStartTime := time.Now()
defer func() { defer func() {
p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime)) p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
// free the buffer in processLogger so the memory can be recovered
p.processLogger.Clear()
}() }()
p.cmdMutex.RLock() p.cmdMutex.RLock()
@@ -646,6 +649,11 @@ func (p *Process) cmdStopUpstreamProcess() error {
return nil return nil
} }
// Logger returns the logger for this process.
func (p *Process) Logger() *LogMonitor {
return p.processLogger
}
var loadingRemarks = []string{ var loadingRemarks = []string{
"Still faster than your last standup meeting...", "Still faster than your last standup meeting...",
"Reticulating splines...", "Reticulating splines...",
@@ -864,7 +872,6 @@ func (s *statusResponseWriter) WriteHeader(statusCode int) {
s.Flush() s.Flush()
} }
// Add Flush method
func (s *statusResponseWriter) Flush() { func (s *statusResponseWriter) Flush() {
if flusher, ok := s.writer.(http.Flusher); ok { if flusher, ok := s.writer.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
+4
View File
@@ -395,6 +395,10 @@ func TestProcess_StopImmediately(t *testing.T) {
// Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates // Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates
// the upstream command // the upstream command
func TestProcess_ForceStopWithKill(t *testing.T) { func TestProcess_ForceStopWithKill(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("skipping SIGTERM test on Windows ") t.Skip("skipping SIGTERM test on Windows ")
} }
+9 -1
View File
@@ -46,7 +46,8 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u
// Create a Process for each member in the group // Create a Process for each member in the group
for _, modelID := range groupConfig.Members { for _, modelID := range groupConfig.Members {
modelConfig, modelID, _ := pg.config.FindConfig(modelID) modelConfig, modelID, _ := pg.config.FindConfig(modelID)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger) processLogger := NewLogMonitorWriter(upstreamLogger)
process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger)
pg.processes[modelID] = process pg.processes[modelID] = process
} }
@@ -88,6 +89,13 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName) return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
} }
func (pg *ProcessGroup) GetMember(modelName string) (*Process, bool) {
if pg.HasMember(modelName) {
return pg.processes[modelName], true
}
return nil, false
}
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error { func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock() pg.Lock()
+4
View File
@@ -49,6 +49,10 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true // TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time. // and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{ Models: map[string]config.ModelConfig{
+286 -152
View File
@@ -3,6 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
@@ -50,19 +51,42 @@ type ProxyManager struct {
buildDate string buildDate string
commit string commit string
version string version string
// peer proxy see: #296, #433
peerProxy *PeerProxy
} }
func New(config config.Config) *ProxyManager { func New(proxyConfig config.Config) *ProxyManager {
// set up loggers // set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
proxyLogger := NewLogMonitorWriter(stdoutLogger)
if config.LogRequests { var muxLogger, upstreamLogger, proxyLogger *LogMonitor
switch proxyConfig.LogToStdout {
case config.LogToStdoutNone:
muxLogger = NewLogMonitorWriter(io.Discard)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(io.Discard)
case config.LogToStdoutBoth:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(muxLogger)
case config.LogToStdoutUpstream:
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(muxLogger)
proxyLogger = NewLogMonitorWriter(io.Discard)
default:
// same as config.LogToStdoutProxy
// helpful because some old tests create a config.Config directly and it
// may not have LogToStdout set explicitly
muxLogger = NewLogMonitorWriter(os.Stdout)
upstreamLogger = NewLogMonitorWriter(io.Discard)
proxyLogger = NewLogMonitorWriter(muxLogger)
}
if proxyConfig.LogRequests {
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.") proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
} }
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) { switch strings.ToLower(strings.TrimSpace(proxyConfig.LogLevel)) {
case "debug": case "debug":
proxyLogger.SetLogLevel(LevelDebug) proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug) upstreamLogger.SetLogLevel(LevelDebug)
@@ -99,7 +123,7 @@ func New(config config.Config) *ProxyManager {
"stampnano": time.StampNano, "stampnano": time.StampNano,
} }
if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(config.LogTimeFormat))]; ok { if timeFormat, ok := timeFormats[strings.ToLower(strings.TrimSpace(proxyConfig.LogTimeFormat))]; ok {
proxyLogger.SetLogTimeFormat(timeFormat) proxyLogger.SetLogTimeFormat(timeFormat)
upstreamLogger.SetLogTimeFormat(timeFormat) upstreamLogger.SetLogTimeFormat(timeFormat)
} }
@@ -107,18 +131,24 @@ func New(config config.Config) *ProxyManager {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int var maxMetrics int
if config.MetricsMaxInMemory <= 0 { if proxyConfig.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback maxMetrics = 1000 // Default fallback
} else { } else {
maxMetrics = config.MetricsMaxInMemory maxMetrics = proxyConfig.MetricsMaxInMemory
}
peerProxy, err := NewPeerProxy(proxyConfig.Peers, proxyLogger)
if err != nil {
proxyLogger.Errorf("Disabling Peering. Failed to create proxy peers: %v", err)
peerProxy = nil
} }
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: proxyConfig,
ginEngine: gin.New(), ginEngine: gin.New(),
proxyLogger: proxyLogger, proxyLogger: proxyLogger,
muxLogger: stdoutLogger, muxLogger: muxLogger,
upstreamLogger: upstreamLogger, upstreamLogger: upstreamLogger,
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics), metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
@@ -131,37 +161,46 @@ func New(config config.Config) *ProxyManager {
buildDate: "unknown", buildDate: "unknown",
commit: "abcd1234", commit: "abcd1234",
version: "0", version: "0",
peerProxy: peerProxy,
} }
// create the process groups // create the process groups
for groupID := range config.Groups { for groupID := range proxyConfig.Groups {
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger) processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger)
pm.processGroups[groupID] = processGroup pm.processGroups[groupID] = processGroup
} }
pm.setupGinEngine() pm.setupGinEngine()
// run any startup hooks // run any startup hooks
if len(config.Hooks.OnStartup.Preload) > 0 { if len(proxyConfig.Hooks.OnStartup.Preload) > 0 {
// do it in the background, don't block startup -- not sure if good idea yet // do it in the background, don't block startup -- not sure if good idea yet
go func() { go func() {
discardWriter := &DiscardWriter{} discardWriter := &DiscardWriter{}
for _, realModelName := range config.Hooks.OnStartup.Preload { for _, preloadModelName := range proxyConfig.Hooks.OnStartup.Preload {
proxyLogger.Infof("Preloading model: %s", realModelName) modelID, ok := proxyConfig.RealModelName(preloadModelName)
processGroup, _, err := pm.swapProcessGroup(realModelName)
if !ok {
proxyLogger.Warnf("Preload model %s not found in config", preloadModelName)
continue
}
proxyLogger.Infof("Preloading model: %s", modelID)
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { if err != nil {
event.Emit(ModelPreloadedEvent{ event.Emit(ModelPreloadedEvent{
ModelName: realModelName, ModelName: modelID,
Success: false, Success: false,
}) })
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err) proxyLogger.Errorf("Failed to preload model %s: %v", modelID, err)
continue continue
} else { } else {
req, _ := http.NewRequest("GET", "/", nil) req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(realModelName, discardWriter, req) processGroup.ProxyRequest(modelID, discardWriter, req)
event.Emit(ModelPreloadedEvent{ event.Emit(ModelPreloadedEvent{
ModelName: realModelName, ModelName: modelID,
Success: true, Success: true,
}) })
} }
@@ -236,37 +275,42 @@ func (pm *ProxyManager) setupGinEngine() {
}) })
// Set up routes using the Gin engine // Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyInferenceHandler) // Protected routes use pm.apiKeyAuth() middleware
pm.ginEngine.POST("/v1/chat/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/responses", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/completions", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570) // Support anthropic /v1/messages (added https://github.com/ggml-org/llama.cpp/pull/17570)
pm.ginEngine.POST("/v1/messages", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/messages", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support embeddings and reranking // Support embeddings and reranking
pm.ginEngine.POST("/v1/embeddings", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/embeddings", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /reranking endpoint + aliases // llama-server's /reranking endpoint + aliases
pm.ginEngine.POST("/reranking", pm.proxyInferenceHandler) pm.ginEngine.POST("/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/rerank", pm.proxyInferenceHandler) pm.ginEngine.POST("/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/rerank", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/reranking", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/reranking", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /infill endpoint for code infilling // llama-server's /infill endpoint for code infilling
pm.ginEngine.POST("/infill", pm.proxyInferenceHandler) pm.ginEngine.POST("/infill", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// llama-server's /completion endpoint // llama-server's /completion endpoint
pm.ginEngine.POST("/completion", pm.proxyInferenceHandler) pm.ginEngine.POST("/completion", pm.apiKeyAuth(), pm.proxyInferenceHandler)
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyInferenceHandler) pm.ginEngine.POST("/v1/audio/speech", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler) pm.ginEngine.POST("/v1/audio/voices", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
pm.ginEngine.POST("/v1/images/generations", pm.apiKeyAuth(), pm.proxyInferenceHandler)
pm.ginEngine.POST("/v1/images/edits", pm.apiKeyAuth(), pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler) pm.ginEngine.GET("/v1/models", pm.apiKeyAuth(), pm.listModelsHandler)
// in proxymanager_loghandlers.go // in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.apiKeyAuth(), pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.apiKeyAuth(), pm.streamLogsHandler)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream/*logMonitorID", pm.apiKeyAuth(), pm.streamLogsHandler)
/** /**
* User Interface Endpoints * User Interface Endpoints
@@ -278,9 +322,9 @@ func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.GET("/upstream", func(c *gin.Context) { pm.ginEngine.GET("/upstream", func(c *gin.Context) {
c.Redirect(http.StatusFound, "/ui/models") c.Redirect(http.StatusFound, "/ui/models")
}) })
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/*upstreamPath", pm.apiKeyAuth(), pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) pm.ginEngine.GET("/unload", pm.apiKeyAuth(), pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) pm.ginEngine.GET("/running", pm.apiKeyAuth(), pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) { pm.ginEngine.GET("/health", func(c *gin.Context) {
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
}) })
@@ -378,16 +422,10 @@ func (pm *ProxyManager) Shutdown() {
pm.shutdownCancel() pm.shutdownCancel()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { func (pm *ProxyManager) swapProcessGroup(realModelName string) (*ProcessGroup, error) {
// de-alias the real model name and get a real one
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
}
processGroup := pm.findGroupByModelName(realModelName) processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil { if processGroup == nil {
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel) return nil, fmt.Errorf("could not find process group for model %s", realModelName)
} }
if processGroup.exclusive { if processGroup.exclusive {
@@ -399,54 +437,71 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
} }
} }
return processGroup, realModelName, nil return processGroup, nil
} }
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := make([]gin.H, 0, len(pm.config.Models)) data := make([]gin.H, 0, len(pm.config.Models))
createdTime := time.Now().Unix() createdTime := time.Now().Unix()
newRecord := func(modelId string, modelConfig config.ModelConfig) gin.H {
record := gin.H{
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted { if modelConfig.Unlisted {
continue continue
} }
newRecord := func(modelId string) gin.H { data = append(data, newRecord(id, modelConfig))
record := gin.H{
"id": modelId,
"object": "model",
"created": createdTime,
"owned_by": "llama-swap",
}
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
// Add metadata if present
if len(modelConfig.Metadata) > 0 {
record["meta"] = gin.H{
"llamaswap": modelConfig.Metadata,
}
}
return record
}
data = append(data, newRecord(id))
// Include aliases // Include aliases
if pm.config.IncludeAliasesInList { if pm.config.IncludeAliasesInList {
for _, alias := range modelConfig.Aliases { for _, alias := range modelConfig.Aliases {
if alias := strings.TrimSpace(alias); alias != "" { if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias)) data = append(data, newRecord(alias, modelConfig))
} }
} }
} }
} }
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
// add peer models
for _, modelID := range peer.Models {
// Skip unlisted models if not showing them
record := newRecord(modelID, config.ModelConfig{
Name: fmt.Sprintf("%s: %s", peerID, modelID),
Metadata: map[string]any{
"peerID": peerID,
},
})
data = append(data, record)
}
}
}
// Sort by the "id" key // Sort by the "id" key
sort.Slice(data, func(i, j int) bool { sort.Slice(data, func(i, j int) bool {
si, _ := data[i]["id"].(string) si, _ := data[i]["id"].(string)
@@ -466,62 +521,61 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
}) })
} }
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { // findModelInPath searches for a valid model name in a path with slashes.
upstreamPath := c.Param("upstreamPath") // It iteratively builds up path segments until it finds a matching model.
// Returns: (searchModelName, realModelName, remainingPath, found)
// split the upstream path by / and search for the model name // Example: "/author/model/endpoint" with model "author/model" -> ("author/model", "author/model", "/endpoint", true)
parts := strings.Split(strings.TrimSpace(upstreamPath), "/") func (pm *ProxyManager) findModelInPath(path string) (searchName string, realName string, remainingPath string, found bool) {
if len(parts) == 0 { parts := strings.Split(strings.TrimSpace(path), "/")
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return
}
modelFound := false
searchModelName := "" searchModelName := ""
var modelName, remainingPath string
for i, part := range parts { for i, part := range parts {
if parts[i] == "" { if part == "" {
continue continue
} }
if searchModelName == "" { if searchModelName == "" {
searchModelName = part searchModelName = part
} else { } else {
searchModelName = searchModelName + "/" + parts[i] searchModelName = searchModelName + "/" + part
} }
if real, ok := pm.config.RealModelName(searchModelName); ok { if modelID, ok := pm.config.RealModelName(searchModelName); ok {
modelName = real return searchModelName, modelID, "/" + strings.Join(parts[i+1:], "/"), true
remainingPath = "/" + strings.Join(parts[i+1:], "/")
modelFound = true
// Check if this is exactly a model name with no additional path
// and doesn't end with a trailing slash
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
// Build new URL with query parameters preserved
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
// Use 308 for non-GET/HEAD requests to preserve method
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
break
} }
} }
return "", "", "", false
}
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
upstreamPath := c.Param("upstreamPath")
searchModelName, modelID, remainingPath, modelFound := pm.findModelInPath(upstreamPath)
if !modelFound { if !modelFound {
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(modelName) // Redirect /upstream/modelname to /upstream/modelname/ for URL consistency.
// This ensures relative URLs in upstream responses resolve correctly and
// provides canonical URL form. Uses 308 for POST/PUT/etc to preserve the
// HTTP method (301 would downgrade to GET).
if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") {
newPath := "/upstream/" + searchModelName + "/"
if c.Request.URL.RawQuery != "" {
newPath += "?" + c.Request.URL.RawQuery
}
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
c.Redirect(http.StatusMovedPermanently, newPath)
} else {
c.Redirect(http.StatusPermanentRedirect, newPath)
}
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
@@ -533,15 +587,15 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// attempt to record metrics if it is a POST request // attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" { if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath) pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", modelID, originalPath)
return return
} }
} else { } else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { if err := processGroup.ProxyRequest(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath) pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", modelID, originalPath)
return return
} }
} }
@@ -560,41 +614,54 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
return return
} }
realModelName, found := pm.config.RealModelName(requestedModel) // Look for a matching local model first
if !found { var nextHandler func(modelID string, w http.ResponseWriter, r *http.Request) error
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
processGroup, _, err := pm.swapProcessGroup(realModelName) modelID, found := pm.config.RealModelName(requestedModel)
if err != nil { if found {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) processGroup, err := pm.swapProcessGroup(modelID)
return
}
// issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName
if useModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
} }
}
// issue #174 strip parameters from the JSON body // issue #69 allow custom model names to be sent to upstream
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams() useModelName := pm.config.Models[modelID].UseModelName
if err != nil { // just log it and continue if useModelName != "" {
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error()) bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
return return
} }
} }
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[modelID].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[modelID].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", modelID, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find suitable inference handler for %s", requestedModel))
return
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -607,19 +674,19 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
// issue #366 extract values that downstream handlers may need // issue #366 extract values that downstream handlers may need
isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool() isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming) ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName) ctx = context.WithValue(ctx, proxyCtxKey("model"), modelID)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
if pm.metricsMonitor != nil && c.Request.Method == "POST" { if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { if err := pm.metricsMonitor.wrapHandler(modelID, c.Writer, c.Request, nextHandler); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName) pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request model %s", modelID)
return return
} }
} else { } else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { if err := nextHandler(modelID, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) pm.proxyLogger.Errorf("Error Proxying Request for model %s", modelID)
return return
} }
} }
@@ -639,7 +706,13 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) modelID, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
processGroup, err := pm.swapProcessGroup(modelID)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
@@ -657,7 +730,7 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// If this is the model field and we have a profile, use just the model name // If this is the model field and we have a profile, use just the model name
if key == "model" { if key == "model" {
// # issue #69 allow custom model names to be sent to upstream // # issue #69 allow custom model names to be sent to upstream
useModelName := pm.config.Models[realModelName].UseModelName useModelName := pm.config.Models[modelID].UseModelName
if useModelName != "" { if useModelName != "" {
fieldValue = useModelName fieldValue = useModelName
@@ -728,9 +801,9 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.ContentLength = int64(requestBuffer.Len()) modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying // Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { if err := processGroup.ProxyRequest(modelID, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, modelID)
return return
} }
} }
@@ -745,6 +818,67 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag
} }
} }
// apiKeyAuth returns a middleware that validates API keys if configured.
// Returns a pass-through handler if no API keys are configured.
func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
if len(pm.config.RequiredAPIKeys) == 0 {
return func(c *gin.Context) { c.Next() }
}
return func(c *gin.Context) {
xApiKey := c.GetHeader("x-api-key")
var bearerKey string
var basicKey string
if auth := c.GetHeader("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
// Basic Auth: base64(username:password), password is the API key
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) == 2 {
basicKey = parts[1] // password is the API key
}
}
}
}
// Use first key found: Basic, then Bearer, then x-api-key
var providedKey string
if basicKey != "" {
providedKey = basicKey
} else if bearerKey != "" {
providedKey = bearerKey
} else {
providedKey = xApiKey
}
// Validate key
valid := false
for _, key := range pm.config.RequiredAPIKeys {
if providedKey == key {
valid = true
break
}
}
if !valid {
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
c.Abort()
return
}
// Strip auth headers to prevent leakage to upstream
c.Request.Header.Del("Authorization")
c.Request.Header.Del("x-api-key")
c.Next()
}
}
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately) pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK") c.String(http.StatusOK, "OK")
+15 -1
View File
@@ -18,11 +18,13 @@ type Model struct {
Description string `json:"description"` Description string `json:"description"`
State string `json:"state"` State string `json:"state"`
Unlisted bool `json:"unlisted"` Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"`
} }
func addApiHandlers(pm *ProxyManager) { func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume // Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api") // Protected with API key authentication
apiGroup := pm.ginEngine.Group("/api", pm.apiKeyAuth())
{ {
apiGroup.POST("/models/unload", pm.apiUnloadAllModels) apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler) apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
@@ -82,6 +84,18 @@ func (pm *ProxyManager) getModelStatus() []Model {
}) })
} }
// Iterate over the peer models
if pm.peerProxy != nil {
for peerID, peer := range pm.peerProxy.ListPeers() {
for _, modelID := range peer.Models {
models = append(models, Model{
Id: modelID,
PeerID: peerID,
})
}
}
}
return models return models
} }
+20 -13
View File
@@ -31,7 +31,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// prevent nginx from buffering streamed logs // prevent nginx from buffering streamed logs
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
logMonitorId := c.Param("logMonitorID") logMonitorId := strings.TrimPrefix(c.Param("logMonitorID"), "/")
logger, err := pm.getLogger(logMonitorId) logger, err := pm.getLogger(logMonitorId)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, err.Error()) c.String(http.StatusBadRequest, err.Error())
@@ -83,18 +83,25 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
// getLogger searches for the appropriate logger based on the logMonitorId // getLogger searches for the appropriate logger based on the logMonitorId
func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) { func (pm *ProxyManager) getLogger(logMonitorId string) (*LogMonitor, error) {
var logger *LogMonitor switch logMonitorId {
case "":
if logMonitorId == "" {
// maintain the default // maintain the default
logger = pm.muxLogger return pm.muxLogger, nil
} else if logMonitorId == "proxy" { case "proxy":
logger = pm.proxyLogger return pm.proxyLogger, nil
} else if logMonitorId == "upstream" { case "upstream":
logger = pm.upstreamLogger return pm.upstreamLogger, nil
} else { default:
return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'") // search for a models specific logger using findModelInPath
} // to handle model names with slashes (e.g., "author/model")
if _, name, _, found := pm.findModelInPath("/" + logMonitorId); found {
for _, group := range pm.processGroups {
if process, found := group.GetMember(name); found {
return process.Logger(), nil
}
}
}
return logger, nil return nil, fmt.Errorf("invalid logger. Use 'proxy', 'upstream' or a model's ID")
}
} }
+382 -12
View File
@@ -3,6 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
@@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel return r.closeChannel
} }
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder { func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{ return &TestResponseRecorder{
httptest.NewRecorder(), httptest.NewRecorder(),
@@ -223,17 +220,23 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
model2Config.Name = " " // empty whitespace only strings will get ignored model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " " model2Config.Description = " "
config := config.Config{ cfg := config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{ Models: map[string]config.ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
Peers: map[string]config.PeerConfig{
"peer1": {
Proxy: "http://peer1:8080",
Models: []string{"peer-model-a", "peer-model-b"},
},
},
LogLevel: "error", LogLevel: "error",
} }
proxy := New(config) proxy := New(cfg)
// Create a test request // Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil) req := httptest.NewRequest("GET", "/v1/models", nil)
@@ -258,14 +261,16 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
t.Fatalf("Failed to parse JSON response: %v", err) t.Fatalf("Failed to parse JSON response: %v", err)
} }
// Check the number of models returned // Check the number of models returned (3 local + 2 peer models)
assert.Len(t, response.Data, 3) assert.Len(t, response.Data, 5)
// Check the details of each model // Check the details of each model
expectedModels := map[string]struct{}{ expectedModels := map[string]struct{}{
"model1": {}, "model1": {},
"model2": {}, "model2": {},
"model3": {}, "model3": {},
"peer-model-a": {},
"peer-model-b": {},
} }
// make all models // make all models
@@ -296,6 +301,19 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
description, ok := model["description"].(string) description, ok := model["description"].(string)
assert.True(t, ok, "description should be a string") assert.True(t, ok, "description should be a string")
assert.Equal(t, "Model 1 description is used for testing", description) assert.Equal(t, "Model 1 description is used for testing", description)
} else if modelID == "peer-model-a" || modelID == "peer-model-b" {
// Peer models should have meta.llamaswap.peerID
meta, exists := model["meta"]
assert.True(t, exists, "peer model should have meta field")
metaMap, ok := meta.(map[string]interface{})
assert.True(t, ok, "meta should be a map")
llamaswap, exists := metaMap["llamaswap"]
assert.True(t, exists, "meta should have llamaswap field")
llamaswapMap, ok := llamaswap.(map[string]interface{})
assert.True(t, ok, "llamaswap should be a map")
peerID, exists := llamaswapMap["peerID"]
assert.True(t, exists, "llamaswap should have peerID field")
assert.Equal(t, "peer1", peerID)
} else { } else {
_, exists := model["name"] _, exists := model["name"]
assert.False(t, exists, "unexpected name field for model: %s", modelID) assert.False(t, exists, "unexpected name field for model: %s", modelID)
@@ -502,6 +520,10 @@ func TestProxyManager_ListModelsHandler_IncludeAliasesInList(t *testing.T) {
} }
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
model1Config.Proxy = "http://localhost:10001/" model1Config.Proxy = "http://localhost:10001/"
@@ -1078,7 +1100,8 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"author/model": getTestSimpleResponderConfig("author/model"),
}, },
LogLevel: "error", LogLevel: "error",
}) })
@@ -1091,6 +1114,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
"/logs/stream", "/logs/stream",
"/logs/stream/proxy", "/logs/stream/proxy",
"/logs/stream/upstream", "/logs/stream/upstream",
"/logs/stream/author/model",
} }
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
@@ -1185,3 +1209,349 @@ func TestProxyManager_ApiGetVersion(t *testing.T) {
assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key]) assert.Equal(t, value, response[key], "%s value %s should match response %s", key, value, response[key])
} }
} }
func TestProxyManager_APIKeyAuth(t *testing.T) {
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
RequiredAPIKeys: []string{"valid-key-1", "valid-key-2"},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
t.Run("valid key in x-api-key header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("valid key in Authorization Bearer header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("Authorization", "Bearer valid-key-2")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("both headers with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
req.Header.Set("Authorization", "Bearer valid-key-1")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("invalid key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "invalid-key")
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})
t.Run("missing key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})
t.Run("valid key in Basic Auth header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
// Basic Auth: base64("anyuser:valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
})
}
func TestProxyManager_APIKeyAuth_Disabled(t *testing.T) {
// Config without RequiredAPIKeys - auth should be disabled
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
t.Run("requests pass without API key when not configured", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})
}
// TestProxyManager_PeerProxy_InferenceHandler tests the peerProxy integration
// in proxyInferenceHandler for issue #433
func TestProxyManager_PeerProxy_InferenceHandler(t *testing.T) {
t.Run("requests to peer models are proxied", func(t *testing.T) {
// Create a test server to act as the peer
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"from-peer","model":"peer-model"}`))
}))
defer peerServer.Close()
// Create config with peers but no local model for "peer-model"
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "from-peer")
})
t.Run("local models take precedence over peer models", func(t *testing.T) {
// Create a test server to act as the peer - should NOT be called
peerCalled := false
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
peerCalled = true
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"from-peer"}`))
}))
defer peerServer.Close()
// Create config where "shared-model" exists both locally and on peer
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- shared-model
models:
shared-model:
cmd: %s -port ${PORT} -silent -respond local-response
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"shared-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "local-response")
assert.False(t, peerCalled, "peer should not be called when local model exists")
})
t.Run("unknown model returns error", func(t *testing.T) {
// Create a test server to act as the peer
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"unknown-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
})
t.Run("peer API key is injected into request", func(t *testing.T) {
var receivedAuthHeader string
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuthHeader = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"response":"ok"}`))
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
apiKey: secret-peer-key
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Bearer secret-peer-key", receivedAuthHeader)
})
t.Run("no peers configured - unknown model returns error", func(t *testing.T) {
testConfig := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"local-model": getTestSimpleResponderConfig("local-model"),
},
LogLevel: "error",
})
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
// peerProxy exists but has no peer models configured
assert.False(t, proxy.peerProxy.HasPeerModel("unknown-model"))
reqBody := `{"model":"unknown-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "could not find suitable inference handler")
})
t.Run("peer streaming response sets X-Accel-Buffering header", func(t *testing.T) {
peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte("data: test\n\n"))
}))
defer peerServer.Close()
configStr := fmt.Sprintf(`
logLevel: error
peers:
test-peer:
proxy: %s
models:
- peer-model
models:
local-model:
cmd: %s -port ${PORT} -silent -respond local-model
`, peerServer.URL, getSimpleResponderPath())
testConfig, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(testConfig)
defer proxy.StopProcesses(StopImmediately)
reqBody := `{"model":"peer-model"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering"))
})
}
+17 -4
View File
@@ -10,6 +10,7 @@ export interface Model {
name: string; name: string;
description: string; description: string;
unlisted: boolean; unlisted: boolean;
peerID: string;
} }
interface APIProviderType { interface APIProviderType {
@@ -70,7 +71,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const [versionInfo, setVersionInfo] = useState<VersionInfo>({ const [versionInfo, setVersionInfo] = useState<VersionInfo>({
build_date: "unknown", build_date: "unknown",
commit: "unknown", commit: "unknown",
version: "unknown" version: "unknown",
}); });
//const apiEventSource = useRef<EventSource | null>(null); //const apiEventSource = useRef<EventSource | null>(null);
@@ -166,7 +167,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
}, []); }, []);
useEffect(() => { useEffect(() => {
// fetch version // fetch version
const fetchVersion = async () => { const fetchVersion = async () => {
try { try {
const response = await fetch("/api/version"); const response = await fetch("/api/version");
@@ -180,7 +181,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
} }
}; };
if (connectionStatus === 'connected') { if (connectionStatus === "connected") {
fetchVersion(); fetchVersion();
} }
}, [connectionStatus]); }, [connectionStatus]);
@@ -265,7 +266,19 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
connectionStatus, connectionStatus,
versionInfo, versionInfo,
}), }),
[models, listModels, unloadAllModels, unloadSingleModel, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics, connectionStatus, versionInfo] [
models,
listModels,
unloadAllModels,
unloadSingleModel,
loadModel,
enableAPIEvents,
proxyLogs,
upstreamLogs,
metrics,
connectionStatus,
versionInfo,
]
); );
return <APIContext.Provider value={value}>{children}</APIContext.Provider>; return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
+49 -16
View File
@@ -44,8 +44,24 @@ function ModelsPanel() {
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
const [menuOpen, setMenuOpen] = useState(false); const [menuOpen, setMenuOpen] = useState(false);
const filteredModels = useMemo(() => { const { regularModels, peerModelsByPeerId } = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted); const filtered = models.filter((model) => showUnlisted || !model.unlisted);
const peerModels = filtered.filter((m) => m.peerID);
// Group peer models by peerID
const grouped = peerModels.reduce((acc, model) => {
const peerId = model.peerID || "unknown";
if (!acc[peerId]) {
acc[peerId] = [];
}
acc[peerId].push(model);
return acc;
}, {} as Record<string, typeof peerModels>);
return {
regularModels: filtered.filter((m) => !m.peerID),
peerModelsByPeerId: grouped,
};
}, [models, showUnlisted]); }, [models, showUnlisted]);
const handleUnloadAllModels = useCallback(async () => { const handleUnloadAllModels = useCallback(async () => {
@@ -151,7 +167,7 @@ function ModelsPanel() {
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{filteredModels.map((model) => ( {regularModels.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200"> <tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
<td className={`${model.unlisted ? "text-txtsecondary" : ""}`}> <td className={`${model.unlisted ? "text-txtsecondary" : ""}`}>
<a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank"> <a href={`/upstream/${model.id}/`} className="font-semibold" target="_blank">
@@ -186,6 +202,34 @@ function ModelsPanel() {
))} ))}
</tbody> </tbody>
</table> </table>
{Object.keys(peerModelsByPeerId).length > 0 && (
<>
<h3 className="mt-8 mb-2">Peer Models</h3>
{Object.entries(peerModelsByPeerId)
.sort(([a], [b]) => a.localeCompare(b))
.map(([peerId, models]) => (
<div key={peerId} className="mb-4">
<table className="w-full">
<thead className="sticky top-0 bg-card z-10">
<tr className="text-left border-b border-gray-200 dark:border-white/10 bg-surface">
<th className="font-semibold">{peerId}</th>
</tr>
</thead>
<tbody>
{models.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-gray-200">
<td className={`pl-8 ${model.unlisted ? "text-txtsecondary" : ""}`}>
<span>{model.id}</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
))}
</>
)}
</div> </div>
</div> </div>
); );
@@ -223,11 +267,7 @@ function TokenHistogram({ data }: { data: HistogramData }) {
return ( return (
<div className="mt-2 w-full"> <div className="mt-2 w-full">
<svg <svg viewBox={`0 0 ${viewBoxWidth} ${height}`} className="w-full h-auto" preserveAspectRatio="xMidYMid meet">
viewBox={`0 0 ${viewBoxWidth} ${height}`}
className="w-full h-auto"
preserveAspectRatio="xMidYMid meet"
>
{/* Y-axis */} {/* Y-axis */}
<line <line
x1={padding.left} x1={padding.left}
@@ -312,14 +352,7 @@ function TokenHistogram({ data }: { data: HistogramData }) {
/> />
{/* X-axis labels */} {/* X-axis labels */}
<text <text x={padding.left} y={height - 5} fontSize="10" fill="currentColor" opacity="0.6" textAnchor="start">
x={padding.left}
y={height - 5}
fontSize="10"
fill="currentColor"
opacity="0.6"
textAnchor="start"
>
{min.toFixed(1)} {min.toFixed(1)}
</text> </text>