Compare commits

...

45 Commits

Author SHA1 Message Date
Benson Wong 5dc6b3e6d9 Add barebones but working implementation of model preload (#209, #235)
Add barebones but working implementation of model preload

* add config test for Preload hook
* improve TestProxyManager_StartupHooks
* docs for new hook configuration
* add a .dev to .gitignore
2025-08-14 10:27:28 -07:00
Benson Wong 74c69f39ef Add prompt processing metrics (#250)
- capture prompt processing metrics
- display prompt processing metrics on UI Activity page
2025-08-14 10:02:16 -07:00
Benson Wong a186318892 Update Readme, Add screenshot for Activities page [skip ci] 2025-08-08 13:39:46 -07:00
Benson Wong c4e4d5e1e9 Update Readme UI Screenshot [skip ci] 2025-08-08 13:33:47 -07:00
Benson Wong 7985e94ba4 add tokens processed to ui models page 2025-08-08 13:28:39 -07:00
Benson Wong 74556c3a36 Update bug-report.md [skip ci] 2025-08-08 09:52:05 -07:00
Benson Wong 5c381e4b30 Add gofmt linting to ci 2025-08-07 20:29:18 -07:00
Benson Wong 10569ed546 Fix model alias usage in upstream path (#230)
Model alias values are not properly resolved and work in upstream/ path.

Related to #229.
2025-08-07 20:16:56 -07:00
Benson Wong 5b10b3c23f UI Tweaks (#228)
* sort model names in UI

* add toggle to show model id/name on UI model page
2025-08-07 11:07:03 -07:00
Benson Wong 45ea792a3a Fix UI panel not saving position correctly 2025-08-06 14:02:22 -07:00
Benson Wong 1bc2802353 fix panels not saving sizing state 2025-08-06 14:00:21 -07:00
Benson Wong 701476c0c4 Update README.md - remove contributor block [skip ci]
Contributor information available on the Github page's sidebar. Redundant.
2025-08-06 11:11:47 -07:00
Ben Greene 5c63e0066c return models sorted by id in /v1/models (#222) 2025-08-06 10:04:52 -07:00
Martin Garton 8be5073c51 Fix typo (#223) [skip ci]
Fix typo `lama-swap` -> `llama-swap`
2025-08-06 10:02:38 -07:00
Aaron Ang 6307bd3205 Add support for building Linux ARM64 binary in Makefile (#221) 2025-08-05 16:26:06 -07:00
Benson Wong 558a72de17 UI Improvements (#219)
- use react-resizable-panels for UI
- improve icons for buttons
- improve mobile layout with drag/resize panels
2025-08-03 17:49:13 -07:00
Leoyzen dc42cf366d Add config monitor support for k8s configmap. (#217) 2025-08-03 08:05:48 -07:00
Ryein Goddard ba0a81937a Update README.md (#216)
Update git clone protocol to https
2025-08-01 19:48:09 -07:00
Benson Wong 574fdfabb4 UI improvements (#213)
* use two column for logs view on wider screens

* hide log controls when panel is minimized
2025-07-31 11:59:21 -07:00
Benson Wong 5172cb2e12 Update docs in Readme [skip ci] 2025-07-30 11:51:14 -07:00
Benson Wong 5672cb03fd Update github actions for notifying homebrew build (#212)
Combine homebrew-llama-swap event with the release action
2025-07-30 11:29:03 -07:00
Benson Wong 0f583163f7 add /health (#211) 2025-07-30 10:37:10 -07:00
Benson Wong 7905fa9ea3 Update trigger-homebrew-update.yml [skip ci] 2025-07-30 10:13:49 -07:00
Ian Sebastian Mathew bbaf172956 add trigger to rebuild homebrew formula (#210) 2025-07-30 10:12:21 -07:00
Benson Wong fd50932dbc Decouple MetricsMiddleware from downstream handlers (#206)
* Decouple MetricsMiddleware from downstream handlers

Remove ls-real-model-name optimization. Within proxyOAIHandler the
request body's bytes are required for various rewriting features
anyways. This negated any benefits from trying not to parse it twice.
2025-07-27 10:36:06 -07:00
Gaël James 8c693e7fcf Add endpoint aliases for reranking models (#201)
* Add endpoint aliases for reranking models
* Add MetricsMiddleware to the previous reranking endpoint
* Fix the embeddings endpoint not having model set
2025-07-24 08:32:47 -07:00
Benson Wong 8f2af26a41 fix stats on model page 2025-07-23 13:57:33 -07:00
Benson Wong 01d4838fb3 Fix token metrics parsing (#199)
Fix #198

- use llama-server's `timings` info if available in response body
- send "-1" for token/sec when not able to accurately calculate
  performance
- optimize streaming body search for metrics information
2025-07-22 23:10:14 -07:00
Benson Wong accd65294b add contributors to README [skip ci] 2025-07-21 23:16:48 -07:00
Benson Wong 7472a25864 Update README.md [skip ci]
update screenshot for web UI
2025-07-21 23:08:19 -07:00
Benson Wong cce0bc6aa1 add guard to ensure ls-real-model-name is set in context 2025-07-21 22:59:41 -07:00
Benson Wong 36e25125e8 UI tidy [skip ci] 2025-07-21 22:47:55 -07:00
Benson Wong 9a54273d15 Update UI with new Activity event stream from #195
- use new metrics data instead of log parsing
- auto-start events connection to server, improves responsiveness
- remove unnecessary libraries and code
2025-07-21 22:42:30 -07:00
g2mt 87dce5f8f6 Add metrics logging for chat completion requests (#195)
- Add token and performance metrics  for v1/chat/completions 
- Add Activity Page in UI
- Add /api/metrics endpoint

Contributed by @g2mt
2025-07-21 22:19:55 -07:00
Benson Wong 307e619521 remove old eventsources from UI 2025-07-19 15:36:40 -07:00
Benson Wong 6299c1b874 Fix High CPU (#189)
* vendor in kelindar/event lib and refactor to remove time.Ticker
2025-07-15 18:04:30 -07:00
Yathi a906cd459b Strip comments before macro expansion in config (#193)
A bug fix that ensures comments don't interfere with macro expansion by
removing them first. This prevents unwanted comment text from appearing
in the final expanded command.

Co-authored-by: Yathiraj Bollimbala G <yathi@yStudio.localdomain>
2025-07-15 10:14:16 -07:00
Benson Wong 78b2bc3dbc add toggle to hide/show unlisted models (#187) 2025-07-02 16:14:20 -07:00
Benson Wong 6a058e4191 Change fsnotify to watch config directory instead of file
The fsnotify library suggests watching a directory and checking that the
name matches the configuration file.
2025-07-02 10:23:52 -07:00
Benson Wong 1921e570d7 Add Event Bus (#184)
Major internal refactor to use an event bus to pass event/messages along. These changes are largely invisible user facing but sets up internal design for real time stats and information.

- `--watch-config` logic refactored for events
- remove multiple SSE api endpoints, replaced with /api/events
- keep all functionality essentially the same
- UI/backend sync is in near real time now
2025-07-01 22:17:35 -07:00
Benson Wong c867a6c9a2 Add name and description to v1/models list (#179)
* Add support for name and description in v1/models list
* add configuration example for name and description
2025-06-30 23:02:44 -07:00
Leoyzen 3bd1b23ce0 fix config hot-reload on k8s (#181)
Co-authored-by: Leoyzen <leoyzen@gmial.com>
2025-06-27 11:49:31 -07:00
srevn 10606abf89 fix config hot-reload on macos (#180)
Co-authored-by: srevn <srevn@github>
2025-06-26 09:20:50 -07:00
Benson Wong fefd14903d improve log display and add a small stats table in ui (#178) 2025-06-25 12:27:49 -07:00
Benson Wong 717d64e336 update GUI image in README [skip ci] 2025-06-24 10:38:28 -07:00
42 changed files with 2585 additions and 564 deletions
+3 -1
View File
@@ -1,11 +1,13 @@
--- ---
name: Bug Report name: Bug Report
about: Something is not working as expected... about: I found a defect
title: '' title: ''
labels: bug labels: bug
assignees: '' assignees: ''
--- ---
> [!IMPORTANT]
> If you have questions about llama-swap please post in the Q&A in Discussions. Use bug reports when you've found a defect and wish to discuss a fix.
**Describe the bug** **Describe the bug**
A clear and concise description of what the bug is. A clear and concise description of what the bug is.
+7
View File
@@ -22,6 +22,13 @@ jobs:
with: with:
go-version: '1.23' go-version: '1.23'
# Only run in this linux based runner
- name: Check Formatting
run: |
if [ "$(gofmt -l . | grep -v 'event/.*_test.go' | wc -l)" -gt 0 ]; then
gofmt -l . | grep -v 'event/.*_test.go'
exit 1
fi
# cache simple-responder to save the build time # cache simple-responder to save the build time
- name: Restore Simple Responder - name: Restore Simple Responder
id: restore-simple-responder id: restore-simple-responder
+33 -3
View File
@@ -7,6 +7,10 @@ on:
# Allows manual triggering of the workflow # Allows manual triggering of the workflow
workflow_dispatch: workflow_dispatch:
inputs:
tag:
description: 'Tag version to release (e.g. v144)'
required: true
permissions: permissions:
contents: write contents: write
@@ -20,15 +24,15 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
- -
name: Set up Go name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
- -
name: Set up Node.js name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: '23' # or your preferred version node-version: '23'
- -
name: Install dependencies and build UI name: Install dependencies and build UI
run: | run: |
@@ -46,4 +50,30 @@ jobs:
version: '~> v2' version: '~> v2'
args: release --clean args: release --clean
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
trigger-tap-update:
runs-on: ubuntu-latest
needs: goreleaser
steps:
- name: "Resolve tag to dispatch"
id: tag
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
echo "tag=${{ github.event.inputs.tag }}" >> "$GITHUB_OUTPUT"
else
echo "tag=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
fi
- name: "Trigger tap repository update"
uses: peter-evans/repository-dispatch@v2
with:
token: ${{ secrets.TAP_REPO_PAT }}
repository: mostlygeek/homebrew-llama-swap
event-type: new-release
client-payload: |
{
"release": {
"tag_name": "${{ steps.tag.outputs.tag }}"
}
}
+1
View File
@@ -4,3 +4,4 @@ build/
dist/ dist/
.vscode .vscode
.DS_Store .DS_Store
.dev/
+1
View File
@@ -45,6 +45,7 @@ mac: ui
linux: ui linux: ui
@echo "Building Linux binary..." @echo "Building Linux binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary # Build Windows binary
windows: ui windows: ui
+41 -14
View File
@@ -18,7 +18,7 @@ Written in golang, it is very easy to install (single binary with no dependencie
- `v1/completions` - `v1/completions`
- `v1/chat/completions` - `v1/chat/completions`
- `v1/embeddings` - `v1/embeddings`
- `v1/rerank` - `v1/rerank`, `v1/reranking`, `rerank`
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
- ✅ llama-swap custom API endpoints - ✅ llama-swap custom API endpoints
@@ -27,23 +27,25 @@ Written in golang, it is very easy to install (single binary with no dependencie
- `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/health` - just returns "OK"
- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Automatic unloading of models after timeout by setting a `ttl`
- ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc)
- ✅ Docker and Podman support - Reliable Docker and Podman support with `cmdStart` and `cmdStop`
- ✅ Full control over server settings per model - ✅ Full control over server settings per model
- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
## How does llama-swap work? ## How does llama-swap work?
When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. When a request is made to an OpenAI compatible endpoint, llama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request.
In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used.
## config.yaml ## config.yaml
llama-swap is managed entirely through a yaml configuration file. llama-swap is managed entirely through a yaml configuration file.
It can be very minimal to start: It can be very minimal to start:
```yaml ```yaml
models: models:
@@ -54,7 +56,7 @@ models:
--port ${PORT} --port ${PORT}
``` ```
However, there are many more capabilities that llama-swap supports: However, there are many more capabilities that llama-swap supports:
- `groups` to run multiple models at once - `groups` to run multiple models at once
- `ttl` to automatically unload models - `ttl` to automatically unload models
@@ -70,15 +72,26 @@ See the [configuration documentation](https://github.com/mostlygeek/llama-swap/w
## Web UI ## Web UI
llama-swap ships with a web based interface to make it easier to monitor logs and check the status of models. llama-swap includes a real time web interface for monitoring logs and models:
<img width="1854" alt="image" src="https://github.com/user-attachments/assets/ee0025f0-f031-4158-9b5d-cd98b2b9fe4d" /> <img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/adef4a8e-de0b-49db-885a-8f6dedae6799" />
The Activity Page shows recent requests:
<img width="1360" height="963" alt="image" src="https://github.com/user-attachments/assets/5f3edee6-d03a-4ae5-ae06-b20ac1f135bd" />
## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) ## Installation
Docker is the quickest way to try out llama-swap: llama-swap can be installed in multiple ways
1. Docker
2. Homebrew (OSX and Linux)
3. From release binaries
4. From source
### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap))
Docker images with llama-swap and llama-server are built nightly.
```shell ```shell
# use CPU inference comes with the example config above # use CPU inference comes with the example config above
@@ -100,7 +113,7 @@ $ curl -s http://localhost:9292/v1/chat/completions \
``` ```
<details> <details>
<summary>Docker images are built nightly for cuda, intel, vulcan, etc ...</summary> <summary>Docker images are built nightly with llama-server for cuda, intel, vulcan and musa.</summary>
They include: They include:
@@ -123,9 +136,23 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
</details> </details>
## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases)) ### Homebrew Install (macOS/Linux)
Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server. The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh).
```shell
# Set up tap and install formula
brew tap mostlygeek/llama-swap
brew install llama-swap
# Run llama-swap
llama-swap --config path/to/config.yaml --listen localhost:8080
```
This will install the `llama-swap` binary and make it available in your path. See the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration)
### Pre-built Binaries ([download](https://github.com/mostlygeek/llama-swap/releases))
Binaries are available for Linux, Mac, Windows and FreeBSD. These are automatically published and are likely a few hours ahead of the docker releases. The binary install works with any OpenAI compatible server, not just llama-server.
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration). 1. Create a configuration file, see the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration).
@@ -139,7 +166,7 @@ Pre-built binaries are available for Linux, Mac, Windows and FreeBSD. These are
### Building from source ### Building from source
1. Build requires golang and nodejs for the user interface. 1. Build requires golang and nodejs for the user interface.
1. `git clone git@github.com:mostlygeek/llama-swap.git` 1. `git clone https://github.com/mostlygeek/llama-swap.git`
1. `make clean all` 1. `make clean all`
1. Binaries will be in `build/` subdirectory 1. Binaries will be in `build/` subdirectory
+43 -2
View File
@@ -1,6 +1,13 @@
# llama-swap YAML configuration example # llama-swap YAML configuration example
# ------------------------------------- # -------------------------------------
# #
# 💡 Tip - Use an LLM with this file!
# ====================================
# This example configuration is written to be LLM friendly! Try
# copying this file into an LLM and asking it to explain or generate
# sections for you.
# ====================================
#
# - Below are all the available configuration options for llama-swap. # - Below are all the available configuration options for llama-swap.
# - Settings with a default value, or noted as optional can be omitted. # - Settings with a default value, or noted as optional can be omitted.
# - Settings that are marked required must be in your configuration file # - Settings that are marked required must be in your configuration file
@@ -15,6 +22,12 @@ healthCheckTimeout: 500
# - Valid log levels: debug, info, warn, error # - Valid log levels: debug, info, warn, error
logLevel: info logLevel: info
# metricsMaxInMemory: maximum number of metrics to keep in memory
# - optional, default: 1000
# - controls how many metrics are stored in memory before older ones are discarded
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
# startPort: sets the starting port number for the automatic ${PORT} macro. # startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800 # - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings # - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -49,7 +62,19 @@ models:
cmd: | cmd: |
# ${latest-llama} is a macro that is defined above # ${latest-llama} is a macro that is defined above
${latest-llama} ${latest-llama}
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf --model path/to/llama-8B-Q4_K_M.gguf
# name: a display name for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
name: "llama 3.1 8B"
# description: a description for the model
# - optional, default: empty string
# - if set, it will be used in the v1/models API response
# - if not set, it will be omitted in the JSON model record
description: "A small but capable model used for quick testing"
# env: define an array of environment variables to inject into cmd's environment # env: define an array of environment variables to inject into cmd's environment
# - optional, default: empty array # - optional, default: empty array
@@ -188,4 +213,20 @@ groups:
members: members:
- "forever-modelA" - "forever-modelA"
- "forever-modelB" - "forever-modelB"
- "forever-modelc" - "forever-modelc"
# hooks: a dictionary of event triggers and actions
# - optional, default: empty dictionary
# - the only supported hook is on_startup
hooks:
# on_startup: a dictionary of actions to perform on startup
# - optional, default: empty dictionar
# - the only supported action is preload
on_startup:
# preload: a list of model ids to load on startup
# - optional, default: empty list
# - model names must match keys in the models sections
# - when preloading multiple models at once, define a group
# otherwise models will be loaded and swapped out
preload:
- "llama"
+1
View File
@@ -1,5 +1,6 @@
healthCheckTimeout: 300 healthCheckTimeout: 300
logRequests: true logRequests: true
metricsMaxInMemory: 1000
models: models:
"qwen2.5": "qwen2.5":
+3
View File
@@ -0,0 +1,3 @@
The code in `event` was originally a part of https://github.com/kelindar/event (v1.5.2)
The original code uses a `time.Ticker` to process the event queue which caused a large increase in CPU usage ([#189](https://github.com/mostlygeek/llama-swap/issues/189)). This code was ported to remove the ticker and instead be more event driven.
+30
View File
@@ -0,0 +1,30 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"context"
)
// Default initializes a default in-process dispatcher
var Default = NewDispatcherConfig(25000)
// On subscribes to an event, the type of the event will be automatically
// inferred from the provided type. Must be constant for this to work. This
// functions same way as Subscribe() but uses the default dispatcher instead.
func On[T Event](handler func(T)) context.CancelFunc {
return Subscribe(Default, handler)
}
// OnType subscribes to an event with the specified event type. This functions
// same way as SubscribeTo() but uses the default dispatcher instead.
func OnType[T Event](eventType uint32, handler func(T)) context.CancelFunc {
return SubscribeTo(Default, eventType, handler)
}
// Emit writes an event into the dispatcher. This functions same way as
// Publish() but uses the default dispatcher instead.
func Emit[T Event](ev T) {
Publish(Default, ev)
}
+54
View File
@@ -0,0 +1,54 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkSubcribeConcurrent-24 1826686 606.3 ns/op 1648 B/op 5 allocs/op
*/
func BenchmarkSubscribeConcurrent(b *testing.B) {
d := NewDispatcher()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
unsub := Subscribe(d, func(ev MyEvent1) {})
unsub()
}
})
}
func TestDefaultPublish(t *testing.T) {
var wg sync.WaitGroup
// Subscribe
var count int64
defer On(func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
defer OnType(TypeEvent1, func(ev MyEvent1) {
atomic.AddInt64(&count, 1)
wg.Done()
})()
// Publish
wg.Add(4)
Emit(MyEvent1{})
Emit(MyEvent1{})
// Wait and check
wg.Wait()
assert.Equal(t, int64(4), count)
}
+324
View File
@@ -0,0 +1,324 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.
package event
import (
"context"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
)
// Event represents an event contract
type Event interface {
Type() uint32
}
// registry holds an immutable sorted array of event mappings
type registry struct {
keys []uint32 // Event types (sorted)
grps []any // Corresponding subscribers
}
// ------------------------------------- Dispatcher -------------------------------------
// Dispatcher represents an event dispatcher.
type Dispatcher struct {
subs atomic.Pointer[registry] // Atomic pointer to immutable array
done chan struct{} // Cancellation
maxQueue int // Maximum queue size per consumer
mu sync.Mutex // Only for writes (subscribe/unsubscribe)
}
// NewDispatcher creates a new dispatcher of events.
func NewDispatcher() *Dispatcher {
return NewDispatcherConfig(50000)
}
// NewDispatcherConfig creates a new dispatcher with configurable max queue size
func NewDispatcherConfig(maxQueue int) *Dispatcher {
d := &Dispatcher{
done: make(chan struct{}),
maxQueue: maxQueue,
}
d.subs.Store(&registry{
keys: make([]uint32, 0, 16),
grps: make([]any, 0, 16),
})
return d
}
// Close closes the dispatcher
func (d *Dispatcher) Close() error {
close(d.done)
return nil
}
// isClosed returns whether the dispatcher is closed or not
func (d *Dispatcher) isClosed() bool {
select {
case <-d.done:
return true
default:
return false
}
}
// findGroup performs a lock-free binary search for the event type
func (d *Dispatcher) findGroup(eventType uint32) any {
reg := d.subs.Load()
keys := reg.keys
// Inlined binary search for better cache locality
left, right := 0, len(keys)
for left < right {
mid := left + (right-left)/2
if keys[mid] < eventType {
left = mid + 1
} else {
right = mid
}
}
if left < len(keys) && keys[left] == eventType {
return reg.grps[left]
}
return nil
}
// Subscribe subscribes to an event, the type of the event will be automatically
// inferred from the provided type. Must be constant for this to work.
func Subscribe[T Event](broker *Dispatcher, handler func(T)) context.CancelFunc {
var event T
return SubscribeTo(broker, event.Type(), handler)
}
// SubscribeTo subscribes to an event with the specified event type.
func SubscribeTo[T Event](broker *Dispatcher, eventType uint32, handler func(T)) context.CancelFunc {
if broker.isClosed() {
panic(errClosed)
}
broker.mu.Lock()
defer broker.mu.Unlock()
// Check if group already exists
if existing := broker.findGroup(eventType); existing != nil {
grp := groupOf[T](eventType, existing)
sub := grp.Add(handler)
return func() {
grp.Del(sub)
}
}
// Create new group
grp := &group[T]{cond: sync.NewCond(new(sync.Mutex)), maxQueue: broker.maxQueue}
sub := grp.Add(handler)
// Copy-on-write: insert new entry in sorted position
old := broker.subs.Load()
idx := sort.Search(len(old.keys), func(i int) bool {
return old.keys[i] >= eventType
})
// Create new arrays with space for one more element
newKeys := make([]uint32, len(old.keys)+1)
newGrps := make([]any, len(old.grps)+1)
// Copy elements before insertion point
copy(newKeys[:idx], old.keys[:idx])
copy(newGrps[:idx], old.grps[:idx])
// Insert new element
newKeys[idx] = eventType
newGrps[idx] = grp
// Copy elements after insertion point
copy(newKeys[idx+1:], old.keys[idx:])
copy(newGrps[idx+1:], old.grps[idx:])
// Atomically store the new registry (mutex ensures no concurrent writers)
newReg := &registry{keys: newKeys, grps: newGrps}
broker.subs.Store(newReg)
return func() {
grp.Del(sub)
}
}
// Publish writes an event into the dispatcher
func Publish[T Event](broker *Dispatcher, ev T) {
eventType := ev.Type()
if sub := broker.findGroup(eventType); sub != nil {
group := groupOf[T](eventType, sub)
group.Broadcast(ev)
}
}
// Count counts the number of subscribers, this is for testing only.
func (d *Dispatcher) count(eventType uint32) int {
if group := d.findGroup(eventType); group != nil {
return group.(interface{ Count() int }).Count()
}
return 0
}
// groupOf casts the subscriber group to the specified generic type
func groupOf[T Event](eventType uint32, subs any) *group[T] {
if group, ok := subs.(*group[T]); ok {
return group
}
panic(errConflict[T](eventType, subs))
}
// ------------------------------------- Subscriber -------------------------------------
// consumer represents a consumer with a message queue
type consumer[T Event] struct {
queue []T // Current work queue
stop bool // Stop signal
}
// Listen listens to the event queue and processes events
func (s *consumer[T]) Listen(c *sync.Cond, fn func(T)) {
pending := make([]T, 0, 128)
for {
c.L.Lock()
for len(s.queue) == 0 {
switch {
case s.stop:
c.L.Unlock()
return
default:
c.Wait()
}
}
// Swap buffers and reset the current queue
temp := s.queue
s.queue = pending[:0]
pending = temp
c.L.Unlock()
// Outside of the critical section, process the work
for _, event := range pending {
fn(event)
}
// Notify potential publishers waiting due to backpressure
c.Broadcast()
}
}
// ------------------------------------- Subscriber Group -------------------------------------
// group represents a consumer group
type group[T Event] struct {
cond *sync.Cond
subs []*consumer[T]
maxQueue int // Maximum queue size per consumer
maxLen int // Current maximum queue length across all consumers
}
// Broadcast sends an event to all consumers
func (s *group[T]) Broadcast(ev T) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
// Calculate current maximum queue length
s.maxLen = 0
for _, sub := range s.subs {
if len(sub.queue) > s.maxLen {
s.maxLen = len(sub.queue)
}
}
// Backpressure: wait if queues are full
for s.maxLen >= s.maxQueue {
s.cond.Wait()
// Recalculate after wakeup
s.maxLen = 0
for _, sub := range s.subs {
if len(sub.queue) > s.maxLen {
s.maxLen = len(sub.queue)
}
}
}
// Add event to all queues and track new maximum
newMax := 0
for _, sub := range s.subs {
sub.queue = append(sub.queue, ev)
if len(sub.queue) > newMax {
newMax = len(sub.queue)
}
}
s.maxLen = newMax
s.cond.Broadcast() // Wake consumers
}
// Add adds a subscriber to the list
func (s *group[T]) Add(handler func(T)) *consumer[T] {
sub := &consumer[T]{
queue: make([]T, 0, 64),
}
// Add the consumer to the list of active consumers
s.cond.L.Lock()
s.subs = append(s.subs, sub)
s.cond.L.Unlock()
// Start listening
go sub.Listen(s.cond, handler)
return sub
}
// Del removes a subscriber from the list
func (s *group[T]) Del(sub *consumer[T]) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
// Search and remove the subscriber
sub.stop = true
for i, v := range s.subs {
if v == sub {
copy(s.subs[i:], s.subs[i+1:])
s.subs = s.subs[:len(s.subs)-1]
break
}
}
}
// ------------------------------------- Debugging -------------------------------------
var errClosed = fmt.Errorf("event dispatcher is closed")
// Count returns the number of subscribers in this group
func (s *group[T]) Count() int {
return len(s.subs)
}
// String returns string representation of the type
func (s *group[T]) String() string {
typ := reflect.TypeOf(s).String()
idx := strings.LastIndex(typ, "/")
typ = typ[idx+1 : len(typ)-1]
return typ
}
// errConflict returns a conflict message
func errConflict[T any](eventType uint32, existing any) string {
var want T
return fmt.Sprintf(
"conflicting event type, want=<%T>, registered=<%s>, event=0x%v",
want, existing, eventType,
)
}
+324
View File
@@ -0,0 +1,324 @@
// Copyright (c) Roman Atachiants and contributore. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for detaile.
package event
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPublish(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
// Subscribe, must be received in order
var count int64
defer Subscribe(d, func(ev MyEvent1) {
assert.Equal(t, int(atomic.AddInt64(&count, 1)), ev.Number)
wg.Done()
})()
// Publish
wg.Add(3)
Publish(d, MyEvent1{Number: 1})
Publish(d, MyEvent1{Number: 2})
Publish(d, MyEvent1{Number: 3})
// Wait and check
wg.Wait()
assert.Equal(t, int64(3), count)
}
func TestUnsubscribe(t *testing.T) {
d := NewDispatcher()
assert.Equal(t, 0, d.count(TypeEvent1))
unsubscribe := Subscribe(d, func(ev MyEvent1) {
// Nothing
})
assert.Equal(t, 1, d.count(TypeEvent1))
unsubscribe()
assert.Equal(t, 0, d.count(TypeEvent1))
}
func TestConcurrent(t *testing.T) {
const max = 1000000
var count int64
var wg sync.WaitGroup
wg.Add(1)
d := NewDispatcher()
defer Subscribe(d, func(ev MyEvent1) {
if current := atomic.AddInt64(&count, 1); current == max {
wg.Done()
}
})()
// Asynchronously publish
go func() {
for i := 0; i < max; i++ {
Publish(d, MyEvent1{})
}
}()
defer Subscribe(d, func(ev MyEvent1) {
// Subscriber that does nothing
})()
wg.Wait()
assert.Equal(t, max, int(count))
}
func TestSubscribeDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent1) {})
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestPublishDifferentType(t *testing.T) {
d := NewDispatcher()
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
Publish(d, MyEvent1{})
})
}
func TestCloseDispatcher(t *testing.T) {
d := NewDispatcher()
defer SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})()
assert.NoError(t, d.Close())
assert.Panics(t, func() {
SubscribeTo(d, TypeEvent1, func(ev MyEvent2) {})
})
}
func TestMatrix(t *testing.T) {
const amount = 1000
for _, subs := range []int{1, 10, 100} {
for _, topics := range []int{1, 10} {
expected := subs * topics * amount
t.Run(fmt.Sprintf("%dx%d", topics, subs), func(t *testing.T) {
var count atomic.Int64
var wg sync.WaitGroup
wg.Add(expected)
d := NewDispatcher()
for i := 0; i < subs; i++ {
for id := 0; id < topics; id++ {
defer SubscribeTo(d, uint32(id), func(ev MyEvent3) {
count.Add(1)
wg.Done()
})()
}
}
for n := 0; n < amount; n++ {
for id := 0; id < topics; id++ {
go Publish(d, MyEvent3{ID: id})
}
}
wg.Wait()
assert.Equal(t, expected, int(count.Load()))
})
}
}
}
func TestConcurrentSubscriptionRace(t *testing.T) {
// This test specifically targets the race condition that occurs when multiple
// goroutines try to subscribe to different event types simultaneously.
// Without the CAS loop, subscriptions could be lost due to registry corruption.
const numGoroutines = 100
const numEventTypes = 50
d := NewDispatcher()
defer d.Close()
var wg sync.WaitGroup
var receivedCount int64
var subscribedTypes sync.Map // Thread-safe map
wg.Add(numGoroutines)
// Start multiple goroutines that subscribe to different event types concurrently
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer wg.Done()
// Each goroutine subscribes to a unique event type
eventType := uint32(goroutineID%numEventTypes + 1000) // Offset to avoid collision with other tests
// Subscribe to the event type
SubscribeTo(d, eventType, func(ev MyEvent3) {
atomic.AddInt64(&receivedCount, 1)
})
// Record that this type was subscribed
subscribedTypes.Store(eventType, true)
}(i)
}
// Wait for all subscriptions to complete
wg.Wait()
// Count the number of unique event types subscribed
expectedTypes := 0
subscribedTypes.Range(func(key, value interface{}) bool {
expectedTypes++
return true
})
// Small delay to ensure all subscriptions are fully processed
time.Sleep(10 * time.Millisecond)
// Publish events to each subscribed type
subscribedTypes.Range(func(key, value interface{}) bool {
eventType := key.(uint32)
Publish(d, MyEvent3{ID: int(eventType)})
return true
})
// Wait for all events to be processed
time.Sleep(50 * time.Millisecond)
// Verify that we received at least the expected number of events
// (there might be more if multiple goroutines subscribed to the same event type)
received := atomic.LoadInt64(&receivedCount)
assert.GreaterOrEqual(t, int(received), expectedTypes,
"Should have received at least %d events, got %d", expectedTypes, received)
// Verify that we have the expected number of unique event types
assert.Equal(t, numEventTypes, expectedTypes,
"Should have exactly %d unique event types", numEventTypes)
}
func TestConcurrentHandlerRegistration(t *testing.T) {
const numGoroutines = 100
// Test concurrent subscriptions to the same event type
t.Run("SameEventType", func(t *testing.T) {
d := NewDispatcher()
var handlerCount int64
var wg sync.WaitGroup
// Start multiple goroutines subscribing to the same event type (0x1)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
SubscribeTo(d, uint32(0x1), func(ev MyEvent1) {
atomic.AddInt64(&handlerCount, 1)
})
}()
}
wg.Wait()
// Verify all handlers were registered by publishing an event
atomic.StoreInt64(&handlerCount, 0)
Publish(d, MyEvent1{})
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int64(numGoroutines), atomic.LoadInt64(&handlerCount),
"Not all handlers were registered due to race condition")
})
// Test concurrent subscriptions to different event types
t.Run("DifferentEventTypes", func(t *testing.T) {
d := NewDispatcher()
var wg sync.WaitGroup
receivedEvents := make(map[uint32]*int64)
// Create multiple event types and subscribe concurrently
for i := 0; i < numGoroutines; i++ {
eventType := uint32(100 + i)
counter := new(int64)
receivedEvents[eventType] = counter
wg.Add(1)
go func(et uint32, cnt *int64) {
defer wg.Done()
SubscribeTo(d, et, func(ev MyEvent3) {
atomic.AddInt64(cnt, 1)
})
}(eventType, counter)
}
wg.Wait()
// Publish events to all types
for eventType := uint32(100); eventType < uint32(100+numGoroutines); eventType++ {
Publish(d, MyEvent3{ID: int(eventType)})
}
// Small delay to ensure all handlers have executed
time.Sleep(10 * time.Millisecond)
// Verify all event types received their events
for eventType, counter := range receivedEvents {
assert.Equal(t, int64(1), atomic.LoadInt64(counter),
"Event type %d did not receive its event", eventType)
}
})
}
func TestBackpressure(t *testing.T) {
d := NewDispatcher()
d.maxQueue = 10
var processedCount int64
unsub := SubscribeTo(d, uint32(0x200), func(ev MyEvent3) {
atomic.AddInt64(&processedCount, 1)
})
defer unsub()
const eventsToPublish = 1000
for i := 0; i < eventsToPublish; i++ {
Publish(d, MyEvent3{ID: 0x200})
}
time.Sleep(100 * time.Millisecond)
// Verify all events were eventually processed
finalProcessed := atomic.LoadInt64(&processedCount)
assert.Equal(t, int64(eventsToPublish), finalProcessed)
t.Logf("Events processed: %d/%d", finalProcessed, eventsToPublish)
}
// ------------------------------------- Test Events -------------------------------------
const (
TypeEvent1 = 0x1
TypeEvent2 = 0x2
)
type MyEvent1 struct {
Number int
}
func (t MyEvent1) Type() uint32 { return TypeEvent1 }
type MyEvent2 struct {
Text string
}
func (t MyEvent2) Type() uint32 { return TypeEvent2 }
type MyEvent3 struct {
ID int
}
func (t MyEvent3) Type() uint32 { return uint32(t.ID) }
+1 -1
View File
@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0 github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
@@ -12,7 +13,6 @@ require (
) )
require ( require (
github.com/billziss-gh/golib v0.2.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
-2
View File
@@ -32,8 +32,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+110 -111
View File
@@ -14,6 +14,7 @@ import (
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
) )
@@ -53,137 +54,135 @@ func main() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
proxyManager := proxy.New(config)
// Setup channels for server management // Setup channels for server management
reloadChan := make(chan *proxy.ProxyManager)
exitChan := make(chan struct{}) exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler // Create server with initial handler
srv := &http.Server{ srv := &http.Server{
Addr: *listenStr, Addr: *listenStr,
Handler: proxyManager,
} }
// Support for watching config and reloading when it changes
reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return
}
fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(config)
fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload
time.AfterFunc(3*time.Second, func() {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateEnd,
})
})
} else {
config, err = proxy.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(config)
}
}
// load the initial proxy manager
reloadProxyManager()
debouncedReload := debounce(time.Second, reloadProxyManager)
if *watchConfig {
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
fmt.Println("Watching Configuration for changes")
go func() {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
return
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
return
}
configDir := filepath.Dir(absConfigPath)
err = watcher.Add(configDir)
if err != nil {
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
return
}
defer watcher.Close()
for {
select {
case changeEvent := <-watcher.Events:
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
// the change for k8s configmap
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
case err := <-watcher.Errors:
log.Printf("File watcher error: %v", err)
}
}
}()
}
// shutdown on signal
go func() {
sig := <-sigChan
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
if pm, ok := srv.Handler.(*proxy.ProxyManager); ok {
pm.Shutdown()
} else {
fmt.Println("srv.Handler is not of type *proxy.ProxyManager")
}
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
}()
// Start server // Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr) fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() { go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Fatal server error: %v\n", err) log.Fatalf("Fatal server error: %v\n", err)
close(exitChan)
} }
}() }()
// Handle config reloads and signals
go func() {
currentManager := proxyManager
for {
select {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses(proxy.StopWaitForInflightRequest)
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
srv.Handler = newManager
log.Println("Server handler updated with new config")
case sig := <-sigChan:
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentManager.Shutdown()
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
return
}
}
}()
// Start file watcher if requested
if *watchConfig {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
} else {
go watchConfigFileWithReload(absConfigPath, reloadChan)
}
}
// Wait for exit signal // Wait for exit signal
<-exitChan <-exitChan
} }
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan. func debounce(interval time.Duration, f func()) func() {
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) { var timer *time.Timer
watcher, err := fsnotify.NewWatcher() return func() {
if err != nil { if timer != nil {
log.Printf("Error creating file watcher: %v. File watching disabled.", err) timer.Stop()
return
}
defer watcher.Close()
err = watcher.Add(configPath)
if err != nil {
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
return
}
log.Printf("Watching config file for changes: %s", configPath)
var debounceTimer *time.Timer
debounceDuration := 2 * time.Second
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
// We only care about writes to the specific config file
if event.Name == configPath && event.Has(fsnotify.Write) {
// Reset or start the debounce timer
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(debounceDuration, func() {
log.Printf("Config file modified: %s, reloading...", event.Name)
// Try up to 3 times with exponential backoff
var newConfig proxy.Config
var err error
for retries := 0; retries < 3; retries++ {
// Load new configuration
newConfig, err = proxy.LoadConfig(configPath)
if err == nil {
break
}
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
if retries < 2 {
time.Sleep(time.Duration(1<<retries) * time.Second)
}
}
if err != nil {
log.Printf("Failed to load new config after retries: %v", err)
return
}
// Create new ProxyManager with new config
newPM := proxy.New(newConfig)
reloadChan <- newPM
log.Println("Config reloaded successfully")
})
}
case err, ok := <-watcher.Errors:
if !ok {
log.Println("File watcher error channel closed.")
return
}
log.Printf("File watcher error: %v", err)
} }
timer = time.AfterFunc(interval, f)
} }
} }
+87 -12
View File
@@ -35,20 +35,90 @@ func main() {
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
r.POST("/v1/chat/completions", func(c *gin.Context) { r.POST("/v1/chat/completions", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
bodyBytes, _ := io.ReadAll(c.Request.Body) bodyBytes, _ := io.ReadAll(c.Request.Body)
c.JSON(http.StatusOK, gin.H{ // Check if streaming is requested
"responseMessage": *responseMessage, // Query is checked instead of JSON body since that event stream conflicts with other tests
"h_content_length": c.Request.Header.Get("Content-Length"), isStreaming := c.Query("stream") == "true"
"request_body": string(bodyBytes),
}) if isStreaming {
// Set headers for streaming
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
// Send 10 "asdf" tokens
for i := 0; i < 10; i++ {
data := gin.H{
"created": time.Now().Unix(),
"choices": []gin.H{
{
"index": 0,
"delta": gin.H{
"content": "asdf",
},
"finish_reason": nil,
},
},
}
c.SSEvent("message", data)
c.Writer.Flush()
}
// Send final data with usage info
finalData := gin.H{
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
// add timings to simulate llama.cpp
"timings": gin.H{
"prompt_n": 25,
"prompt_ms": 13,
"predicted_n": 10,
"predicted_ms": 17,
"predicted_per_second": 10,
},
}
c.SSEvent("message", finalData)
c.Writer.Flush()
// Send [DONE]
c.SSEvent("message", "[DONE]")
c.Writer.Flush()
} else {
c.Header("Content-Type", "application/json")
// add a wait to simulate a slow query
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
time.Sleep(wait)
}
c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage,
"h_content_length": c.Request.Header.Get("Content-Length"),
"request_body": string(bodyBytes),
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
"timings": gin.H{
"prompt_n": 25,
"prompt_ms": 13,
"predicted_n": 10,
"predicted_ms": 17,
"predicted_per_second": 10,
},
})
}
}) })
// for issue #62 to check model name strips profile slug // for issue #62 to check model name strips profile slug
@@ -74,6 +144,11 @@ func main() {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"responseMessage": *responseMessage, "responseMessage": *responseMessage,
"usage": gin.H{
"completion_tokens": 10,
"prompt_tokens": 25,
"total_tokens": 35,
},
}) })
}) })
+52
View File
@@ -28,6 +28,10 @@ type ModelConfig struct {
Unlisted bool `yaml:"unlisted"` Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"` UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process // Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"` ConcurrencyLimit int `yaml:"concurrencyLimit"`
@@ -48,6 +52,8 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
Unlisted: false, Unlisted: false,
UseModelName: "", UseModelName: "",
ConcurrencyLimit: 0, ConcurrencyLimit: 0,
Name: "",
Description: "",
} }
// the default cmdStop to taskkill /f /t /pid ${PID} // the default cmdStop to taskkill /f /t /pid ${PID}
@@ -132,10 +138,19 @@ func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return nil return nil
} }
type HooksConfig struct {
OnStartup HookOnStartup `yaml:"on_startup"`
}
type HookOnStartup struct {
Preload []string `yaml:"preload"`
}
type Config struct { type Config struct {
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
LogRequests bool `yaml:"logRequests"` LogRequests bool `yaml:"logRequests"`
LogLevel string `yaml:"logLevel"` LogLevel string `yaml:"logLevel"`
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
Profiles map[string][]string `yaml:"profiles"` Profiles map[string][]string `yaml:"profiles"`
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
@@ -148,6 +163,9 @@ type Config struct {
// automatic port assignments // automatic port assignments
StartPort int `yaml:"startPort"` StartPort int `yaml:"startPort"`
// hooks, see: #209
Hooks HooksConfig `yaml:"hooks"`
} }
func (c *Config) RealModelName(search string) (string, bool) { func (c *Config) RealModelName(search string) (string, bool) {
@@ -188,6 +206,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
HealthCheckTimeout: 120, HealthCheckTimeout: 120,
StartPort: 5800, StartPort: 5800,
LogLevel: "info", LogLevel: "info",
MetricsMaxInMemory: 1000,
} }
err = yaml.Unmarshal(data, &config) err = yaml.Unmarshal(data, &config)
if err != nil { if err != nil {
@@ -249,6 +268,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
for _, modelId := range modelIds { for _, modelId := range modelIds {
modelConfig := config.Models[modelId] modelConfig := config.Models[modelId]
// Strip comments from command fields before macro expansion
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values // go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros { for macroName, macroValue := range config.Macros {
macroSlug := fmt.Sprintf("${%s}", macroName) macroSlug := fmt.Sprintf("${%s}", macroName)
@@ -318,6 +341,22 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
// clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
return config, nil return config, nil
} }
@@ -400,3 +439,16 @@ func SanitizeCommand(cmdStr string) ([]string, error) {
return args, nil return args, nil
} }
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
+13
View File
@@ -100,10 +100,15 @@ func TestConfig_LoadPosix(t *testing.T) {
content := ` content := `
macros: macros:
svr-path: "path/to/server" svr-path: "path/to/server"
hooks:
on_startup:
preload: ["model1", "model2"]
models: models:
model1: model1:
cmd: path/to/cmd --arg1 one cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080" proxy: "http://localhost:8080"
name: "Model 1"
description: "This is model 1"
aliases: aliases:
- "m1" - "m1"
- "model-one" - "model-one"
@@ -161,6 +166,11 @@ groups:
Macros: map[string]string{ Macros: map[string]string{
"svr-path": "path/to/server", "svr-path": "path/to/server",
}, },
Hooks: HooksConfig{
OnStartup: HookOnStartup{
Preload: []string{"model1", "model2"},
},
},
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
Cmd: "path/to/cmd --arg1 one", Cmd: "path/to/cmd --arg1 one",
@@ -168,6 +178,8 @@ groups:
Aliases: []string{"m1", "model-one"}, Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
}, },
"model2": { "model2": {
Cmd: "path/to/server --arg1 one", Cmd: "path/to/server --arg1 one",
@@ -192,6 +204,7 @@ groups:
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
Profiles: map[string][]string{ Profiles: map[string][]string{
"test": {"model1", "model2"}, "test": {"model1", "model2"},
}, },
+115
View File
@@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"slices"
"strings" "strings"
"testing" "testing"
@@ -325,3 +326,117 @@ models:
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized) assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
} }
} }
func TestStripComments(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no comments",
input: "echo hello\necho world",
expected: "echo hello\necho world",
},
{
name: "single comment line",
input: "# this is a comment\necho hello",
expected: "echo hello",
},
{
name: "multiple comment lines",
input: "# comment 1\necho hello\n# comment 2\necho world",
expected: "echo hello\necho world",
},
{
name: "comment with spaces",
input: " # indented comment\necho hello",
expected: "echo hello",
},
{
name: "empty lines preserved",
input: "echo hello\n\necho world",
expected: "echo hello\n\necho world",
},
{
name: "only comments",
input: "# comment 1\n# comment 2",
expected: "",
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripComments(tt.input)
if result != tt.expected {
t.Errorf("StripComments() = %q, expected %q", result, tt.expected)
}
})
}
}
func TestConfig_MacroInCommentStrippedBeforeExpansion(t *testing.T) {
// Test case that reproduces the original bug where a macro in a comment
// would get expanded and cause the comment text to be included in the command
content := `
startPort: 9990
macros:
"latest-llama": >
/user/llama.cpp/build/bin/llama-server
--port ${PORT}
models:
"test-model":
cmd: |
# ${latest-llama} is a macro that is defined above
${latest-llama}
--model /path/to/model.gguf
-ngl 99
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
// Get the sanitized command
sanitizedCmd, err := SanitizeCommand(config.Models["test-model"].Cmd)
assert.NoError(t, err)
// Join the command for easier inspection
cmdStr := strings.Join(sanitizedCmd, " ")
// Verify that comment text is NOT present in the final command as separate arguments
commentWords := []string{"is", "macro", "that", "defined", "above"}
for _, word := range commentWords {
found := slices.Contains(sanitizedCmd, word)
assert.False(t, found, "Comment text '%s' should not be present as a separate argument in final command", word)
}
// Verify that the actual command components ARE present
expectedParts := []string{
"/user/llama.cpp/build/bin/llama-server",
"--port",
"9990",
"--model",
"/path/to/model.gguf",
"-ngl",
"99",
}
for _, part := range expectedParts {
assert.Contains(t, cmdStr, part, "Expected command part '%s' not found in final command", part)
}
// Verify the server path appears exactly once (not duplicated due to macro expansion)
serverPath := "/user/llama.cpp/build/bin/llama-server"
count := strings.Count(cmdStr, serverPath)
assert.Equal(t, 1, count, "Expected exactly 1 occurrence of server path, found %d", count)
// Verify the expected final command structure
expectedCmd := "/user/llama.cpp/build/bin/llama-server --port 9990 --model /path/to/model.gguf -ngl 99"
assert.Equal(t, expectedCmd, cmdStr, "Final command does not match expected structure")
}
+1
View File
@@ -193,6 +193,7 @@ groups:
}, },
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
MetricsMaxInMemory: 1000,
Profiles: map[string][]string{ Profiles: map[string][]string{
"test": {"model1", "model2"}, "test": {"model1", "model2"},
}, },
+27
View File
@@ -0,0 +1,27 @@
package proxy
import "net/http"
// Custom discard writer that implements http.ResponseWriter but just discards everything
type DiscardWriter struct {
header http.Header
status int
}
func (w *DiscardWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *DiscardWriter) Write(data []byte) (int, error) {
return len(data), nil
}
func (w *DiscardWriter) WriteHeader(code int) {
w.status = code
}
// Satisfy the http.Flusher interface for streaming responses
func (w *DiscardWriter) Flush() {}
+60
View File
@@ -0,0 +1,60 @@
package proxy
// package level registry of the different event types
const ProcessStateChangeEventID = 0x01
const ChatCompletionStatsEventID = 0x02
const ConfigFileChangedEventID = 0x03
const LogDataEventID = 0x04
const TokenMetricsEventID = 0x05
const ModelPreloadedEventID = 0x06
type ProcessStateChangeEvent struct {
ProcessName string
NewState ProcessState
OldState ProcessState
}
func (e ProcessStateChangeEvent) Type() uint32 {
return ProcessStateChangeEventID
}
type ChatCompletionStats struct {
TokensGenerated int
}
func (e ChatCompletionStats) Type() uint32 {
return ChatCompletionStatsEventID
}
type ReloadingState int
const (
ReloadingStateStart ReloadingState = iota
ReloadingStateEnd
)
type ConfigFileChangedEvent struct {
ReloadingState ReloadingState
}
func (e ConfigFileChangedEvent) Type() uint32 {
return ConfigFileChangedEventID
}
type LogDataEvent struct {
Data []byte
}
func (e LogDataEvent) Type() uint32 {
return LogDataEventID
}
type ModelPreloadedEvent struct {
ModelName string
Success bool
}
func (e ModelPreloadedEvent) Type() uint32 {
return ModelPreloadedEventID
}
+5 -6
View File
@@ -13,9 +13,10 @@ import (
) )
var ( var (
nextTestPort int = 12000 nextTestPort int = 12000
portMutex sync.Mutex portMutex sync.Mutex
testLogger = NewLogMonitorWriter(os.Stdout) testLogger = NewLogMonitorWriter(os.Stdout)
simpleResponderPath = getSimpleResponderPath()
) )
// Check if the binary exists // Check if the binary exists
@@ -69,13 +70,11 @@ func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
} }
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
binaryPath := getSimpleResponderPath()
// Create a YAML string with just the values we want to set // Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(` yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s' cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d" proxy: "http://127.0.0.1:%d"
`, binaryPath, port, expectedMessage, port) `, simpleResponderPath, port, expectedMessage, port)
var cfg ModelConfig var cfg ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
+14 -31
View File
@@ -2,10 +2,13 @@ package proxy
import ( import (
"container/ring" "container/ring"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
"sync" "sync"
"github.com/mostlygeek/llama-swap/event"
) )
type LogLevel int type LogLevel int
@@ -18,7 +21,7 @@ const (
) )
type LogMonitor struct { type LogMonitor struct {
clients map[chan []byte]bool eventbus *event.Dispatcher
mu sync.RWMutex mu sync.RWMutex
buffer *ring.Ring buffer *ring.Ring
bufferMu sync.RWMutex bufferMu sync.RWMutex
@@ -37,11 +40,11 @@ func NewLogMonitor() *LogMonitor {
func NewLogMonitorWriter(stdout io.Writer) *LogMonitor { func NewLogMonitorWriter(stdout io.Writer) *LogMonitor {
return &LogMonitor{ return &LogMonitor{
clients: make(map[chan []byte]bool), eventbus: event.NewDispatcherConfig(1000),
buffer: ring.New(10 * 1024), // keep 10KB of buffered logs buffer: ring.New(10 * 1024), // keep 10KB of buffered logs
stdout: stdout, stdout: stdout,
level: LevelInfo, level: LevelInfo,
prefix: "", prefix: "",
} }
} }
@@ -81,34 +84,14 @@ func (w *LogMonitor) GetHistory() []byte {
return history return history
} }
func (w *LogMonitor) Subscribe() chan []byte { func (w *LogMonitor) OnLogData(callback func(data []byte)) context.CancelFunc {
w.mu.Lock() return event.Subscribe(w.eventbus, func(e LogDataEvent) {
defer w.mu.Unlock() callback(e.Data)
})
ch := make(chan []byte, 100)
w.clients[ch] = true
return ch
}
func (w *LogMonitor) Unsubscribe(ch chan []byte) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.clients, ch)
close(ch)
} }
func (w *LogMonitor) broadcast(msg []byte) { func (w *LogMonitor) broadcast(msg []byte) {
w.mu.RLock() event.Publish(w.eventbus, LogDataEvent{Data: msg})
defer w.mu.RUnlock()
for client := range w.clients {
select {
case client <- msg:
default:
// If client buffer is full, skip
}
}
} }
func (w *LogMonitor) SetPrefix(prefix string) { func (w *LogMonitor) SetPrefix(prefix string) {
+13 -22
View File
@@ -10,38 +10,29 @@ import (
func TestLogMonitor(t *testing.T) { func TestLogMonitor(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard) logMonitor := NewLogMonitorWriter(io.Discard)
// Test subscription // A WaitGroup is used to wait for all the expected writes to complete
client1 := logMonitor.Subscribe() var wg sync.WaitGroup
client2 := logMonitor.Subscribe()
defer logMonitor.Unsubscribe(client1)
defer logMonitor.Unsubscribe(client2)
client1Messages := make([]byte, 0) client1Messages := make([]byte, 0)
client2Messages := make([]byte, 0) client2Messages := make([]byte, 0)
var wg sync.WaitGroup defer logMonitor.OnLogData(func(data []byte) {
wg.Add(1) client1Messages = append(client1Messages, data...)
wg.Done()
})()
go func() { defer logMonitor.OnLogData(func(data []byte) {
defer wg.Done() client2Messages = append(client2Messages, data...)
for { wg.Done()
select { })()
case data := <-client1:
client1Messages = append(client1Messages, data...) wg.Add(6) // 2 x 3 writes
case data := <-client2:
client2Messages = append(client2Messages, data...)
default:
return
}
}
}()
logMonitor.Write([]byte("1")) logMonitor.Write([]byte("1"))
logMonitor.Write([]byte("2")) logMonitor.Write([]byte("2"))
logMonitor.Write([]byte("3")) logMonitor.Write([]byte("3"))
// Wait for the goroutine to finish // wait for all writes to complete
wg.Wait() wg.Wait()
// Check the buffer // Check the buffer
+173
View File
@@ -0,0 +1,173 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
rec := writer.metricsRecorder
rec.processBody(writer.body)
}
}
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
isStreaming bool
startTime time.Time
}
// processBody handles response processing after request completes
func (rec *MetricsRecorder) processBody(body []byte) {
if rec.isStreaming {
rec.processStreamingResponse(body)
} else {
rec.processNonStreamingResponse(body)
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
if !usage.Exists() {
return false
}
// default values
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings := jsonData.Get("timings"); timings.Exists() {
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
+83
View File
@@ -0,0 +1,83 @@
package proxy
import (
"encoding/json"
"sync"
"time"
"github.com/mostlygeek/llama-swap/event"
)
// TokenMetrics represents parsed token statistics from llama-server logs
type TokenMetrics struct {
ID int `json:"id"`
Timestamp time.Time `json:"timestamp"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"`
TokensPerSecond float64 `json:"tokens_per_second"`
DurationMs int `json:"duration_ms"`
}
// TokenMetricsEvent represents a token metrics event
type TokenMetricsEvent struct {
Metrics TokenMetrics
}
func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
// MetricsMonitor parses llama-server output for token statistics
type MetricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
}
func NewMetricsMonitor(config *Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
}
mp := &MetricsMonitor{
maxMetrics: maxMetrics,
}
return mp
}
// addMetrics adds a new metric to the collection and publishes an event
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
metric.ID = mp.nextID
mp.nextID++
mp.metrics = append(mp.metrics, metric)
if len(mp.metrics) > mp.maxMetrics {
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
}
event.Emit(TokenMetricsEvent{Metrics: metric})
}
// GetMetrics returns a copy of the current metrics
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
result := make([]TokenMetrics, len(mp.metrics))
copy(result, mp.metrics)
return result
}
// GetMetricsJSON returns metrics as JSON
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
+6 -3
View File
@@ -13,6 +13,8 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"github.com/mostlygeek/llama-swap/event"
) )
type ProcessState string type ProcessState string
@@ -127,6 +129,7 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
p.state = newState p.state = newState
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
return p.state, nil return p.state, nil
} }
@@ -209,11 +212,11 @@ func (p *Process) start() error {
if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.state = StateStopped // force it into a stopped state p.state = StateStopped // force it into a stopped state
return fmt.Errorf( return fmt.Errorf(
"failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v", "failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
err, curState, swapErr, strings.Join(args, " "), err, curState, swapErr,
) )
} }
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed for command '%s': %v", strings.Join(args, " "), err)
} }
// Capture the exit error for later signalling // Capture the exit error for later signalling
+1 -1
View File
@@ -107,7 +107,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
w = httptest.NewRecorder() w = httptest.NewRecorder()
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "start() failed: ") assert.Contains(t, w.Body.String(), "start() failed for command 'nonexistent-command':")
} }
func TestProcess_UnloadAfterTTL(t *testing.T) { func TestProcess_UnloadAfterTTL(t *testing.T) {
+94 -22
View File
@@ -2,18 +2,20 @@ package proxy
import ( import (
"bytes" "bytes"
"encoding/json" "context"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os" "os"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -33,7 +35,13 @@ type ProxyManager struct {
upstreamLogger *LogMonitor upstreamLogger *LogMonitor
muxLogger *LogMonitor muxLogger *LogMonitor
metricsMonitor *MetricsMonitor
processGroups map[string]*ProcessGroup processGroups map[string]*ProcessGroup
// shutdown signaling
shutdownCtx context.Context
shutdownCancel context.CancelFunc
} }
func New(config Config) *ProxyManager { func New(config Config) *ProxyManager {
@@ -64,6 +72,8 @@ func New(config Config) *ProxyManager {
upstreamLogger.SetLogLevel(LevelInfo) upstreamLogger.SetLogLevel(LevelInfo)
} }
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: config,
ginEngine: gin.New(), ginEngine: gin.New(),
@@ -72,7 +82,12 @@ func New(config Config) *ProxyManager {
muxLogger: stdoutLogger, muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger, upstreamLogger: upstreamLogger,
metricsMonitor: NewMetricsMonitor(&config),
processGroups: make(map[string]*ProcessGroup), processGroups: make(map[string]*ProcessGroup),
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
} }
// create the process groups // create the process groups
@@ -82,6 +97,35 @@ func New(config Config) *ProxyManager {
} }
pm.setupGinEngine() pm.setupGinEngine()
// run any startup hooks
if len(config.Hooks.OnStartup.Preload) > 0 {
// do it in the background, don't block startup -- not sure if good idea yet
go func() {
discardWriter := &DiscardWriter{}
for _, realModelName := range config.Hooks.OnStartup.Preload {
proxyLogger.Infof("Preloading model: %s", realModelName)
processGroup, _, err := pm.swapProcessGroup(realModelName)
if err != nil {
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
Success: false,
})
proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
continue
} else {
req, _ := http.NewRequest("GET", "/", nil)
processGroup.ProxyRequest(realModelName, discardWriter, req)
event.Emit(ModelPreloadedEvent{
ModelName: realModelName,
Success: true,
})
}
}
}()
}
return pm return pm
} }
@@ -140,14 +184,18 @@ func (pm *ProxyManager) setupGinEngine() {
c.Next() c.Next()
}) })
mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine // Set up routes using the Gin engine
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12 // Support legacy /v1/completions api, see issue #12
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
// Support embeddings // Support embeddings
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
// Support audio/speech endpoint // Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
@@ -158,9 +206,7 @@ func (pm *ProxyManager) setupGinEngine() {
// in proxymanager_loghandlers.go // in proxymanager_loghandlers.go
pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)
/** /**
* User Interface Endpoints * User Interface Endpoints
@@ -176,6 +222,9 @@ func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) { pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil { if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
@@ -262,6 +311,7 @@ func (pm *ProxyManager) Shutdown() {
}(processGroup) }(processGroup)
} }
wg.Wait() wg.Wait()
pm.shutdownCancel()
} }
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
@@ -289,32 +339,48 @@ func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup,
} }
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := make([]gin.H, 0, len(pm.config.Models))
createdTime := time.Now().Unix()
for id, modelConfig := range pm.config.Models { for id, modelConfig := range pm.config.Models {
if modelConfig.Unlisted { if modelConfig.Unlisted {
continue continue
} }
data = append(data, map[string]interface{}{ record := gin.H{
"id": id, "id": id,
"object": "model", "object": "model",
"created": time.Now().Unix(), "created": createdTime,
"owned_by": "llama-swap", "owned_by": "llama-swap",
}) }
if name := strings.TrimSpace(modelConfig.Name); name != "" {
record["name"] = name
}
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
record["description"] = desc
}
data = append(data, record)
} }
// Set the Content-Type header to application/json // Sort by the "id" key
c.Header("Content-Type", "application/json") sort.Slice(data, func(i, j int) bool {
si, _ := data[i]["id"].(string)
sj, _ := data[j]["id"].(string)
return si < sj
})
if origin := c.Request.Header.Get("Origin"); origin != "" { // Set CORS headers if origin exists
if origin := c.GetHeader("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
} }
// Encode the data as JSON and write it to the response writer // Use gin's JSON method which handles content-type and encoding
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil { c.JSON(http.StatusOK, gin.H{
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error())) "object": "list",
return "data": data,
} })
} }
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
@@ -325,7 +391,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, _, err := pm.swapProcessGroup(requestedModel) processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
@@ -333,7 +399,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// rewrite the path // rewrite the path
c.Request.URL.Path = c.Param("upstreamPath") c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
} }
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
@@ -349,7 +415,13 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return return
} }
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}
processGroup, _, err := pm.swapProcessGroup(realModelName)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
+121 -18
View File
@@ -1,25 +1,30 @@
package proxy package proxy
import ( import (
"context"
"encoding/json"
"net/http" "net/http"
"sort" "sort"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
) )
type Model struct { type Model struct {
Id string `json:"id"` Id string `json:"id"`
State string `json:"state"` Name string `json:"name"`
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
} }
func addApiHandlers(pm *ProxyManager) { func addApiHandlers(pm *ProxyManager) {
// Add API endpoints for React to consume // Add API endpoints for React to consume
apiGroup := pm.ginEngine.Group("/api") apiGroup := pm.ginEngine.Group("/api")
{ {
apiGroup.GET("/models", pm.apiListModels)
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE)
apiGroup.POST("/models/unload", pm.apiUnloadAllModels) apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
} }
} }
@@ -65,37 +70,135 @@ func (pm *ProxyManager) getModelStatus() []Model {
} }
} }
models = append(models, Model{ models = append(models, Model{
Id: modelID, Id: modelID,
State: state, Name: pm.config.Models[modelID].Name,
Description: pm.config.Models[modelID].Description,
State: state,
Unlisted: pm.config.Models[modelID].Unlisted,
}) })
} }
return models return models
} }
func (pm *ProxyManager) apiListModels(c *gin.Context) { type messageType string
c.JSON(http.StatusOK, pm.getModelStatus())
const (
msgTypeModelStatus messageType = "modelStatus"
msgTypeLogData messageType = "logData"
msgTypeMetrics messageType = "metrics"
)
type messageEnvelope struct {
Type messageType `json:"type"`
Data string `json:"data"`
} }
// stream the models as a SSE // sends a stream of different message types that happen on the server
func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) { func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
notify := c.Request.Context().Done() sendBuffer := make(chan messageEnvelope, 25)
ctx, cancel := context.WithCancel(c.Request.Context())
sendModels := func() {
data, err := json.Marshal(pm.getModelStatus())
if err == nil {
msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)}
select {
case sendBuffer <- msg:
case <-ctx.Done():
return
default:
}
}
}
sendLogData := func(source string, data []byte) {
data, err := json.Marshal(gin.H{
"source": source,
"data": string(data),
})
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}:
case <-ctx.Done():
return
default:
}
}
}
sendMetrics := func(metrics TokenMetrics) {
jsonData, err := json.Marshal(metrics)
if err == nil {
select {
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
case <-ctx.Done():
return
default:
}
}
}
/**
* Send updated models list
*/
defer event.On(func(e ProcessStateChangeEvent) {
sendModels()
})()
defer event.On(func(e ConfigFileChangedEvent) {
sendModels()
})()
/**
* Send Log data
*/
defer pm.proxyLogger.OnLogData(func(data []byte) {
sendLogData("proxy", data)
})()
defer pm.upstreamLogger.OnLogData(func(data []byte) {
sendLogData("upstream", data)
})()
/**
* Send Metrics data
*/
defer event.On(func(e TokenMetricsEvent) {
sendMetrics(e.Metrics)
})()
// send initial batch of data
sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels()
for _, metrics := range pm.metricsMonitor.GetMetrics() {
sendMetrics(metrics)
}
// Stream new events
for { for {
select { select {
case <-notify: case <-c.Request.Context().Done():
cancel()
return return
default: case <-pm.shutdownCtx.Done():
models := pm.getModelStatus() cancel()
c.SSEvent("message", models) return
case msg := <-sendBuffer:
c.SSEvent("message", msg)
c.Writer.Flush() c.Writer.Flush()
<-time.After(1000 * time.Millisecond)
} }
} }
} }
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return
}
c.Data(http.StatusOK, "application/json", jsonData)
}
+20 -51
View File
@@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@@ -34,10 +35,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
c.String(http.StatusBadRequest, err.Error()) c.String(http.StatusBadRequest, err.Error())
return return
} }
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported")) c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported"))
@@ -55,57 +53,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) {
} }
} }
// Stream new logs sendChan := make(chan []byte, 10)
ctx, cancel := context.WithCancel(c.Request.Context())
defer logger.OnLogData(func(data []byte) {
select {
case sendChan <- data:
case <-ctx.Done():
return
default:
}
})()
for { for {
select { select {
case msg := <-ch: case <-c.Request.Context().Done():
_, err := c.Writer.Write(msg) cancel()
if err != nil { return
// just break the loop if we can't write for some reason case <-pm.shutdownCtx.Done():
return cancel()
} return
case data := <-sendChan:
c.Writer.Write(data)
flusher.Flush() flusher.Flush()
case <-notify:
return
}
}
}
func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Content-Type-Options", "nosniff")
logMonitorId := c.Param("logMonitorID")
logger, err := pm.getLogger(logMonitorId)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
ch := logger.Subscribe()
defer logger.Unsubscribe(ch)
notify := c.Request.Context().Done()
// Send history first if not skipped
_, skipHistory := c.GetQuery("no-history")
if !skipHistory {
history := logger.GetHistory()
if len(history) != 0 {
c.SSEvent("message", string(history))
c.Writer.Flush()
}
}
// Stream new logs
for {
select {
case msg := <-ch:
c.SSEvent("message", string(msg))
c.Writer.Flush()
case <-notify:
return
} }
} }
} }
+253 -18
View File
@@ -9,10 +9,12 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv" "strconv"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -165,9 +167,11 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
mu.Lock() mu.Lock()
var response map[string]string var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"] result, ok := response["responseMessage"].(string)
assert.Equal(t, ok, true)
results[key] = result
mu.Unlock() mu.Unlock()
}(key) }(key)
@@ -183,11 +187,20 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler(t *testing.T) {
model1Config := getTestSimpleResponderConfig("model1")
model1Config.Name = "Model 1"
model1Config.Description = "Model 1 description is used for testing"
model2Config := getTestSimpleResponderConfig("model2")
model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " "
config := Config{ config := Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": model1Config,
"model2": getTestSimpleResponderConfig("model2"), "model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -213,6 +226,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
var response struct { var response struct {
Data []map[string]interface{} `json:"data"` Data []map[string]interface{} `json:"data"`
} }
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err) t.Fatalf("Failed to parse JSON response: %v", err)
} }
@@ -227,6 +241,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
"model3": {}, "model3": {},
} }
// make all models
for _, model := range response.Data { for _, model := range response.Data {
modelID, ok := model["id"].(string) modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string") assert.True(t, ok, "model ID should be a string")
@@ -245,12 +260,72 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
ownedBy, ok := model["owned_by"].(string) ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string") assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy) assert.Equal(t, "llama-swap", ownedBy)
// check for optional name and description
if modelID == "model1" {
name, ok := model["name"].(string)
assert.True(t, ok, "name should be a string")
assert.Equal(t, "Model 1", name)
description, ok := model["description"].(string)
assert.True(t, ok, "description should be a string")
assert.Equal(t, "Model 1 description is used for testing", description)
} else {
_, exists := model["name"]
assert.False(t, exists, "unexpected name field for model: %s", modelID)
_, exists = model["description"]
assert.False(t, exists, "unexpected description field for model: %s", modelID)
}
} }
// Ensure all expected models were returned // Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned") assert.Empty(t, expectedModels, "not all expected models were returned")
} }
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Intentionally add models in non-sorted order and with an unlisted model
config := Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"zeta": getTestSimpleResponderConfig("zeta"),
"alpha": getTestSimpleResponderConfig("alpha"),
"beta": getTestSimpleResponderConfig("beta"),
"hidden": func() ModelConfig {
mc := getTestSimpleResponderConfig("hidden")
mc.Unlisted = true
return mc
}(),
},
LogLevel: "error",
}
proxy := New(config)
// Request models list
req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// We expect only the listed models in sorted order by id
expectedOrder := []string{"alpha", "beta", "zeta"}
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
got := make([]string, 0, len(response.Data))
for _, m := range response.Data {
id, _ := m["id"].(string)
got = append(got, id)
}
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
}
}
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
// make broken model configurations // make broken model configurations
model1Config := getTestSimpleResponderConfigPort("model1", 9991) model1Config := getTestSimpleResponderConfigPort("model1", 9991)
@@ -583,21 +658,34 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
} }
func TestProxyManager_Upstream(t *testing.T) { func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ configStr := fmt.Sprintf(`
HealthCheckTimeout: 15, logLevel: error
Models: map[string]ModelConfig{ models:
"model1": getTestSimpleResponderConfig("model1"), model1:
}, cmd: %s -port ${PORT} -silent -respond model1
LogLevel: "error", aliases: [model-alias]
}) `, getSimpleResponderPath())
config, err := LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) t.Run("main model name", func(t *testing.T) {
rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
proxy.ServeHTTP(rec, req) rec := httptest.NewRecorder()
assert.Equal(t, http.StatusOK, rec.Code) proxy.ServeHTTP(rec, req)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
})
t.Run("model alias", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
})
} }
func TestProxyManager_ChatContentLength(t *testing.T) { func TestProxyManager_ChatContentLength(t *testing.T) {
@@ -618,7 +706,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"]) assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"]) assert.Equal(t, "model1", response["responseMessage"])
@@ -646,7 +734,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string var response map[string]interface{}
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// `temperature` and `stream` are gone but model remains // `temperature` and `stream` are gone but model remains
@@ -657,3 +745,150 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
// assert.Equal(t, "abc", response["y_param"]) // assert.Equal(t, "abc", response["y_param"])
// t.Logf("%v", response) // t.Logf("%v", response)
} }
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a streaming request
reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
}
func TestProxyManager_StartupHooks(t *testing.T) {
// using real YAML as the configuration has gotten more complex
// is the right approach as LoadConfigFromReader() does a lot more
// than parse YAML now. Eventually migrate all tests to use this approach
configStr := strings.Replace(`
logLevel: error
hooks:
on_startup:
preload:
- model1
- model2
groups:
preloadTestGroup:
swap: false
members:
- model1
- model2
models:
model1:
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1
model2:
cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2
`, "${simpleresponderpath}", simpleResponderPath, -1)
// Create a test model configuration
config, err := LoadConfigFromReader(strings.NewReader(configStr))
if !assert.NoError(t, err, "Invalid configuration") {
return
}
preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events
unsub := event.On(func(e ModelPreloadedEvent) {
preloadChan <- e
})
defer unsub()
// Create the proxy which should trigger preloading
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
for i := 0; i < 2; i++ {
select {
case <-preloadChan:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for models to preload")
}
}
// make sure they are both loaded
_, foundGroup := proxy.processGroups["preloadTestGroup"]
if !assert.True(t, foundGroup, "preloadTestGroup should exist") {
return
}
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState())
assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState())
}
+21
View File
@@ -12,6 +12,8 @@
"@tanstack/react-query": "^5.80.6", "@tanstack/react-query": "^5.80.6",
"react": "^19.1.0", "react": "^19.1.0",
"react-dom": "^19.1.0", "react-dom": "^19.1.0",
"react-icons": "^5.5.0",
"react-resizable-panels": "^3.0.4",
"react-router-dom": "^7.6.2", "react-router-dom": "^7.6.2",
"tailwindcss": "^4.1.8" "tailwindcss": "^4.1.8"
}, },
@@ -3460,6 +3462,15 @@
"react": "^19.1.0" "react": "^19.1.0"
} }
}, },
"node_modules/react-icons": {
"version": "5.5.0",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.5.0.tgz",
"integrity": "sha512-MEFcXdkP3dLo8uumGI5xN3lDFNsRtrjbOEKDLD7yv76v4wpnEq2Lt2qeHaQOr34I/wPN3s3+N08WkQ+CW37Xiw==",
"license": "MIT",
"peerDependencies": {
"react": "*"
}
},
"node_modules/react-refresh": { "node_modules/react-refresh": {
"version": "0.17.0", "version": "0.17.0",
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.17.0.tgz", "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.17.0.tgz",
@@ -3470,6 +3481,16 @@
"node": ">=0.10.0" "node": ">=0.10.0"
} }
}, },
"node_modules/react-resizable-panels": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/react-resizable-panels/-/react-resizable-panels-3.0.4.tgz",
"integrity": "sha512-8Y4KNgV94XhUvI2LeByyPIjoUJb71M/0hyhtzkHaqpVHs+ZQs8b627HmzyhmVYi3C9YP6R+XD1KmG7hHjEZXFQ==",
"license": "MIT",
"peerDependencies": {
"react": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc",
"react-dom": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
}
},
"node_modules/react-router": { "node_modules/react-router": {
"version": "7.6.2", "version": "7.6.2",
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.6.2.tgz", "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.6.2.tgz",
+3 -1
View File
@@ -14,6 +14,8 @@
"@tanstack/react-query": "^5.80.6", "@tanstack/react-query": "^5.80.6",
"react": "^19.1.0", "react": "^19.1.0",
"react-dom": "^19.1.0", "react-dom": "^19.1.0",
"react-icons": "^5.5.0",
"react-resizable-panels": "^3.0.4",
"react-router-dom": "^7.6.2", "react-router-dom": "^7.6.2",
"tailwindcss": "^4.1.8" "tailwindcss": "^4.1.8"
}, },
@@ -30,4 +32,4 @@
"typescript-eslint": "^8.30.1", "typescript-eslint": "^8.30.1",
"vite": "^6.3.5" "vite": "^6.3.5"
} }
} }
+17 -9
View File
@@ -3,17 +3,20 @@ import { useTheme } from "./contexts/ThemeProvider";
import { APIProvider } from "./contexts/APIProvider"; import { APIProvider } from "./contexts/APIProvider";
import LogViewerPage from "./pages/LogViewer"; import LogViewerPage from "./pages/LogViewer";
import ModelPage from "./pages/Models"; import ModelPage from "./pages/Models";
import ActivityPage from "./pages/Activity";
import { RiSunFill, RiMoonFill } from "react-icons/ri";
function App() { function App() {
const theme = useTheme(); const { isNarrow, toggleTheme, isDarkMode } = useTheme();
return ( return (
<Router basename="/ui/"> <Router basename="/ui/">
<APIProvider> <APIProvider>
<div> <div className="flex flex-col h-screen">
<nav className="bg-surface border-b border-border p-4"> <nav className="bg-surface border-b border-border p-2 h-[75px]">
<div className="flex items-center justify-between mx-auto px-4"> <div className="flex items-center justify-between mx-auto px-4 h-full">
<h1>llama-swap</h1> {!isNarrow && <h1 className="flex items-center p-0">llama-swap</h1>}
<div className="flex space-x-4"> <div className="flex items-center space-x-4">
<NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}> <NavLink to="/" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Logs Logs
</NavLink> </NavLink>
@@ -21,17 +24,22 @@ function App() {
<NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}> <NavLink to="/models" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Models Models
</NavLink> </NavLink>
<button className="btn btn--sm" onClick={theme.toggleTheme}>
{theme.isDarkMode ? "🌙" : "☀️"} <NavLink to="/activity" className={({ isActive }) => (isActive ? "navlink active" : "navlink")}>
Activity
</NavLink>
<button className="" onClick={toggleTheme}>
{isDarkMode ? <RiMoonFill /> : <RiSunFill />}
</button> </button>
</div> </div>
</div> </div>
</nav> </nav>
<main className="mx-auto py-4 px-4"> <main className="flex-1 overflow-auto p-4">
<Routes> <Routes>
<Route path="/" element={<LogViewerPage />} /> <Route path="/" element={<LogViewerPage />} />
<Route path="/models" element={<ModelPage />} /> <Route path="/models" element={<ModelPage />} />
<Route path="/activity" element={<ActivityPage />} />
<Route path="*" element={<Navigate to="/" replace />} /> <Route path="*" element={<Navigate to="/" replace />} />
</Routes> </Routes>
</main> </main>
+98 -115
View File
@@ -6,6 +6,9 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export interface Model { export interface Model {
id: string; id: string;
state: ModelStatus; state: ModelStatus;
name: string;
description: string;
unlisted: boolean;
} }
interface APIProviderType { interface APIProviderType {
@@ -13,26 +16,45 @@ interface APIProviderType {
listModels: () => Promise<Model[]>; listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>; unloadAllModels: () => Promise<void>;
loadModel: (model: string) => Promise<void>; loadModel: (model: string) => Promise<void>;
enableProxyLogs: (enabled: boolean) => void; enableAPIEvents: (enabled: boolean) => void;
enableUpstreamLogs: (enabled: boolean) => void;
enableModelUpdates: (enabled: boolean) => void;
proxyLogs: string; proxyLogs: string;
upstreamLogs: string; upstreamLogs: string;
metrics: Metrics[];
}
interface Metrics {
id: number;
timestamp: string;
model: string;
input_tokens: number;
output_tokens: number;
prompt_per_second: number;
tokens_per_second: number;
duration_ms: number;
}
interface LogData {
source: "upstream" | "proxy";
data: string;
}
interface APIEventEnvelope {
type: "modelStatus" | "logData" | "metrics";
data: string;
} }
const APIContext = createContext<APIProviderType | undefined>(undefined); const APIContext = createContext<APIProviderType | undefined>(undefined);
type APIProviderProps = { type APIProviderProps = {
children: ReactNode; children: ReactNode;
autoStartAPIEvents?: boolean;
}; };
export function APIProvider({ children }: APIProviderProps) { export function APIProvider({ children, autoStartAPIEvents = true }: APIProviderProps) {
const [proxyLogs, setProxyLogs] = useState(""); const [proxyLogs, setProxyLogs] = useState("");
const [upstreamLogs, setUpstreamLogs] = useState(""); const [upstreamLogs, setUpstreamLogs] = useState("");
const proxyEventSource = useRef<EventSource | null>(null); const [metrics, setMetrics] = useState<Metrics[]>([]);
const upstreamEventSource = useRef<EventSource | null>(null); const apiEventSource = useRef<EventSource | null>(null);
const [models, setModels] = useState<Model[]>([]); const [models, setModels] = useState<Model[]>([]);
const modelStatusEventSource = useRef<EventSource | null>(null);
const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => { const appendLog = useCallback((newData: string, setter: React.Dispatch<React.SetStateAction<string>>) => {
setter((prev) => { setter((prev) => {
@@ -41,112 +63,84 @@ export function APIProvider({ children }: APIProviderProps) {
}); });
}, []); }, []);
const handleProxyMessage = useCallback( const enableAPIEvents = useCallback((enabled: boolean) => {
(e: MessageEvent) => { if (!enabled) {
appendLog(e.data, setProxyLogs); apiEventSource.current?.close();
}, apiEventSource.current = null;
[proxyLogs, appendLog] setMetrics([]);
); return;
}
const handleUpstreamMessage = useCallback( let retryCount = 0;
(e: MessageEvent) => { const initialDelay = 1000; // 1 second
appendLog(e.data, setUpstreamLogs);
},
[appendLog]
);
const enableProxyLogs = useCallback( const connect = () => {
(enabled: boolean) => { const eventSource = new EventSource("/api/events");
if (enabled) {
let retryCount = 0;
const maxRetries = 3;
const initialDelay = 1000; // 1 second
const connect = () => { eventSource.onmessage = (e: MessageEvent) => {
const eventSource = new EventSource("/logs/streamSSE/proxy"); try {
const message = JSON.parse(e.data) as APIEventEnvelope;
switch (message.type) {
case "modelStatus":
{
const models = JSON.parse(message.data) as Model[];
eventSource.onmessage = handleProxyMessage; // sort models by name and id
eventSource.onerror = () => { models.sort((a, b) => {
eventSource.close(); return (a.name + a.id).localeCompare(b.name + b.id);
if (retryCount < maxRetries) { });
retryCount++;
const delay = initialDelay * Math.pow(2, retryCount - 1);
setTimeout(connect, delay);
}
};
proxyEventSource.current = eventSource; setModels(models);
}; }
break;
connect(); case "logData":
} else { const logData = JSON.parse(message.data) as LogData;
proxyEventSource.current?.close(); switch (logData.source) {
proxyEventSource.current = null; case "proxy":
} appendLog(logData.data, setProxyLogs);
}, break;
[handleProxyMessage] case "upstream":
); appendLog(logData.data, setUpstreamLogs);
break;
}
break;
const enableUpstreamLogs = useCallback( case "metrics":
(enabled: boolean) => { {
if (enabled) { const newMetric = JSON.parse(message.data) as Metrics;
let retryCount = 0; setMetrics((prevMetrics) => {
const maxRetries = 3; return [newMetric, ...prevMetrics];
const initialDelay = 1000; // 1 second });
}
const connect = () => { break;
const eventSource = new EventSource("/logs/streamSSE/upstream");
eventSource.onmessage = handleUpstreamMessage;
eventSource.onerror = () => {
eventSource.close();
if (retryCount < maxRetries) {
retryCount++;
const delay = initialDelay * Math.pow(2, retryCount - 1);
setTimeout(connect, delay);
}
};
upstreamEventSource.current = eventSource;
};
connect();
} else {
upstreamEventSource.current?.close();
upstreamEventSource.current = null;
}
},
[handleUpstreamMessage]
);
const enableModelUpdates = useCallback(
(enabled: boolean) => {
if (enabled) {
const eventSource = new EventSource("/api/modelsSSE");
eventSource.onmessage = (e: MessageEvent) => {
try {
const models = JSON.parse(e.data) as Model[];
setModels(models);
} catch (e) {
console.error(e);
} }
}; } catch (err) {
modelStatusEventSource.current = eventSource; console.error(e.data, err);
} else { }
modelStatusEventSource.current?.close(); };
modelStatusEventSource.current = null; eventSource.onerror = () => {
} eventSource.close();
}, retryCount++;
[setModels] const delay = Math.min(initialDelay * Math.pow(2, retryCount - 1), 5000);
); setTimeout(connect, delay);
};
apiEventSource.current = eventSource;
};
connect();
}, []);
useEffect(() => { useEffect(() => {
if (autoStartAPIEvents) {
enableAPIEvents(true);
}
return () => { return () => {
proxyEventSource.current?.close(); enableAPIEvents(false);
upstreamEventSource.current?.close();
modelStatusEventSource.current?.close();
}; };
}, []); }, [enableAPIEvents, autoStartAPIEvents]);
const listModels = useCallback(async (): Promise<Model[]> => { const listModels = useCallback(async (): Promise<Model[]> => {
try { try {
@@ -196,23 +190,12 @@ export function APIProvider({ children }: APIProviderProps) {
listModels, listModels,
unloadAllModels, unloadAllModels,
loadModel, loadModel,
enableProxyLogs, enableAPIEvents,
enableUpstreamLogs,
enableModelUpdates,
proxyLogs, proxyLogs,
upstreamLogs, upstreamLogs,
metrics,
}), }),
[ [models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs, metrics]
models,
listModels,
unloadAllModels,
loadModel,
enableProxyLogs,
enableUpstreamLogs,
enableModelUpdates,
proxyLogs,
upstreamLogs,
]
); );
return <APIContext.Provider value={value}>{children}</APIContext.Provider>; return <APIContext.Provider value={value}>{children}</APIContext.Provider>;
+37 -2
View File
@@ -1,8 +1,11 @@
import { createContext, useContext, useEffect, type ReactNode } from "react"; import { createContext, useContext, useEffect, type ReactNode, useMemo, useState } from "react";
import { usePersistentState } from "../hooks/usePersistentState"; import { usePersistentState } from "../hooks/usePersistentState";
type ScreenWidth = "xs" | "sm" | "md" | "lg" | "xl" | "2xl";
type ThemeContextType = { type ThemeContextType = {
isDarkMode: boolean; isDarkMode: boolean;
screenWidth: ScreenWidth;
isNarrow: boolean;
toggleTheme: () => void; toggleTheme: () => void;
}; };
@@ -14,14 +17,46 @@ type ThemeProviderProps = {
export function ThemeProvider({ children }: ThemeProviderProps) { export function ThemeProvider({ children }: ThemeProviderProps) {
const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false); const [isDarkMode, setIsDarkMode] = usePersistentState<boolean>("theme", false);
const [screenWidth, setScreenWidth] = useState<ScreenWidth>("md"); // Default to md
// matches tailwind classes
// https://tailwindcss.com/docs/responsive-design
useEffect(() => {
const checkInnerWidth = () => {
const innerWidth = window.innerWidth;
if (innerWidth < 640) {
setScreenWidth("xs");
} else if (innerWidth < 768) {
setScreenWidth("sm");
} else if (innerWidth < 1024) {
setScreenWidth("md");
} else if (innerWidth < 1280) {
setScreenWidth("lg");
} else if (innerWidth < 1536) {
setScreenWidth("xl");
} else {
setScreenWidth("2xl");
}
};
checkInnerWidth();
window.addEventListener("resize", checkInnerWidth);
return () => window.removeEventListener("resize", checkInnerWidth);
}, []);
useEffect(() => { useEffect(() => {
document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light"); document.documentElement.setAttribute("data-theme", isDarkMode ? "dark" : "light");
}, [isDarkMode]); }, [isDarkMode]);
const toggleTheme = () => setIsDarkMode((prev) => !prev); const toggleTheme = () => setIsDarkMode((prev) => !prev);
const isNarrow = useMemo(() => {
return screenWidth === "xs" || screenWidth === "sm" || screenWidth === "md";
}, [screenWidth]);
return <ThemeContext.Provider value={{ isDarkMode, toggleTheme }}>{children}</ThemeContext.Provider>; return (
<ThemeContext.Provider value={{ isDarkMode, toggleTheme, screenWidth, isNarrow }}>{children}</ThemeContext.Provider>
);
} }
export function useTheme(): ThemeContextType { export function useTheme(): ThemeContextType {
+79
View File
@@ -0,0 +1,79 @@
import { useState, useEffect } from "react";
import { useAPI } from "../contexts/APIProvider";
const formatTimestamp = (timestamp: string): string => {
return new Date(timestamp).toLocaleString();
};
const formatSpeed = (speed: number): string => {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
};
const formatDuration = (ms: number): string => {
return (ms / 1000).toFixed(2) + "s";
};
const ActivityPage = () => {
const { metrics } = useAPI();
const [error, setError] = useState<string | null>(null);
useEffect(() => {
if (metrics.length > 0) {
setError(null);
}
}, [metrics]);
if (error) {
return (
<div className="p-6">
<h1 className="text-2xl font-bold mb-4">Activity</h1>
<div className="bg-red-50 border border-red-200 rounded-md p-4">
<p className="text-red-800">{error}</p>
</div>
</div>
);
}
return (
<div className="p-6">
<h1 className="text-2xl font-bold mb-4">Activity</h1>
{metrics.length === 0 ? (
<div className="text-center py-8">
<p className="text-gray-600">No metrics data available</p>
</div>
) : (
<div className="overflow-x-auto">
<table className="min-w-full divide-y">
<thead>
<tr>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Timestamp</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Model</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Input Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Output Tokens</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Prompt Processing</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Generation Speed</th>
<th className="px-6 py-3 text-left text-xs font-medium uppercase tracking-wider">Duration</th>
</tr>
</thead>
<tbody className="divide-y">
{metrics.map((metric, index) => (
<tr key={`${metric.id}-${index}`}>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatTimestamp(metric.timestamp)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.model}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.input_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{metric.output_tokens.toLocaleString()}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.prompt_per_second)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatSpeed(metric.tokens_per_second)}</td>
<td className="px-6 py-4 whitespace-nowrap text-sm">{formatDuration(metric.duration_ms)}</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</div>
);
};
export default ActivityPage;
+69 -67
View File
@@ -1,24 +1,38 @@
import { useState, useEffect, useRef, useMemo, useCallback } from "react"; import { useState, useEffect, useRef, useMemo, useCallback } from "react";
import { useAPI } from "../contexts/APIProvider"; import { useAPI } from "../contexts/APIProvider";
import { usePersistentState } from "../hooks/usePersistentState"; import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import {
RiTextWrap,
RiAlignJustify,
RiFontSize,
RiMenuSearchLine,
RiMenuSearchFill,
RiCloseCircleFill,
} from "react-icons/ri";
import { useTheme } from "../contexts/ThemeProvider";
const LogViewer = () => { const LogViewer = () => {
const { proxyLogs, upstreamLogs, enableProxyLogs, enableUpstreamLogs } = useAPI(); const { proxyLogs, upstreamLogs } = useAPI();
const { isNarrow } = useTheme();
useEffect(() => { const direction = isNarrow ? "vertical" : "horizontal";
enableProxyLogs(true);
enableUpstreamLogs(true);
return () => {
enableProxyLogs(false);
enableUpstreamLogs(false);
};
}, []);
return ( return (
<div className="flex flex-col gap-5"> <PanelGroup direction={direction} className="gap-2" autoSaveId="logviewer-panel-group">
<LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} /> <Panel id="proxy" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} /> <LogPanel id="proxy" title="Proxy Logs" logData={proxyLogs} />
</div> </Panel>
<PanelResizeHandle
className={
direction === "horizontal"
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
}
/>
<Panel id="upstream" defaultSize={50} minSize={5} maxSize={100} collapsible={true}>
<LogPanel id="upstream" title="Upstream Logs" logData={upstreamLogs} />
</Panel>
</PanelGroup>
); );
}; };
@@ -26,20 +40,15 @@ interface LogPanelProps {
id: string; id: string;
title: string; title: string;
logData: string; logData: string;
className?: string;
} }
export const LogPanel = ({ id, title, logData }: LogPanelProps) => {
export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
const [filterRegex, setFilterRegex] = useState(""); const [filterRegex, setFilterRegex] = useState("");
const [panelState, setPanelState] = usePersistentState<"hide" | "small" | "max">(
`logPanel-${id}-panelState`,
"small"
);
const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">( const [fontSize, setFontSize] = usePersistentState<"xxs" | "xs" | "small" | "normal">(
`logPanel-${id}-fontSize`, `logPanel-${id}-fontSize`,
"normal" "normal"
); );
const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false); const [wrapText, setTextWrap] = usePersistentState(`logPanel-${id}-wrapText`, false);
const [showFilter, setShowFilter] = usePersistentState(`logPanel-${id}-showFilter`, false);
const textWrapClass = useMemo(() => { const textWrapClass = useMemo(() => {
return wrapText ? "whitespace-pre-wrap" : "whitespace-pre"; return wrapText ? "whitespace-pre-wrap" : "whitespace-pre";
@@ -60,14 +69,19 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
}); });
}, []); }, []);
const togglePanelState = useCallback(() => { const toggleWrapText = useCallback(() => {
setPanelState((prev) => { setTextWrap((prev) => !prev);
if (prev === "small") return "max";
if (prev === "hide") return "small";
return "hide";
});
}, []); }, []);
const toggleFilter = useCallback(() => {
if (showFilter) {
setShowFilter(false);
setFilterRegex(""); // Clear filter when closing
} else {
setShowFilter(true);
}
}, [filterRegex, setFilterRegex, showFilter]);
const fontSizeClass = useMemo(() => { const fontSizeClass = useMemo(() => {
switch (fontSize) { switch (fontSize) {
case "xxs": case "xxs":
@@ -101,60 +115,48 @@ export const LogPanel = ({ id, title, logData, className }: LogPanelProps) => {
}, [filteredLogs]); }, [filteredLogs]);
return ( return (
<div className={`bg-surface border border-border rounded-lg overflow-hidden flex flex-col ${className || ""}`}> <div className="bg-surface border border-border rounded-lg overflow-hidden flex flex-col h-full">
<div className="p-4 border-b border-border bg-secondary"> <div className="p-4 border-b border-border bg-secondary">
<div className="flex flex-col md:flex-row md:items-center md:justify-between gap-4"> <div className="flex items-center justify-between">
{/* Title - Always full width on mobile, normal on desktop */} <h3 className="m-0 text-lg p-0">{title}</h3>
<div className="w-full md:w-auto" onClick={togglePanelState}>
<h3 className="m-0 text-lg">{title}</h3> <div className="flex gap-2 items-center">
<button className="btn" onClick={toggleFontSize}>
<RiFontSize />
</button>
<button className="btn" onClick={toggleWrapText}>
{wrapText ? <RiTextWrap /> : <RiAlignJustify />}
</button>
<button className="btn" onClick={toggleFilter}>
{showFilter ? <RiMenuSearchFill /> : <RiMenuSearchLine />}
</button>
</div> </div>
</div>
<div className="flex flex-col sm:flex-row gap-4 w-full md:w-auto"> {/* Filtering Options - Full width on mobile, normal on desktop */}
{/* Sizing Buttons - Stacks vertically on mobile */} {showFilter && (
<div className="flex flex-wrap gap-2"> <div className="mt-2 w-full">
<button className="btn" onClick={togglePanelState}> <div className="flex gap-2 items-center w-full">
size: {panelState}
</button>
<button className="btn" onClick={toggleFontSize}>
font: {fontSize}
</button>
<button className="btn" onClick={() => setTextWrap((prev) => !prev)}>
{wrapText ? "wrap" : "wrap off"}
</button>
</div>
{/* Filtering Options - Full width on mobile, normal on desktop */}
<div className="flex flex-1 min-w-0 gap-2">
<input <input
type="text" type="text"
className="flex-1 min-w-[120px] text-sm border p-2 rounded" className="w-full text-sm border p-2 rounded"
placeholder="Filter logs..." placeholder="Filter logs..."
value={filterRegex} value={filterRegex}
onChange={(e) => setFilterRegex(e.target.value)} onChange={(e) => setFilterRegex(e.target.value)}
/> />
<button className="btn" onClick={() => setFilterRegex("")}> <button className="pl-2" onClick={() => setFilterRegex("")}>
Clear <RiCloseCircleFill size="24" />
</button> </button>
</div> </div>
</div> </div>
</div> )}
</div>
<div className="bg-background font-mono text-sm flex-1 overflow-hidden">
<pre ref={preTagRef} className={`${textWrapClass} ${fontSizeClass} h-full overflow-auto p-4`}>
{filteredLogs}
</pre>
</div> </div>
{panelState !== "hide" && (
<div className="flex-1 bg-background font-mono text-sm leading-[1.4] p-3">
<pre
ref={preTagRef}
className={`flex-1 p-4 overflow-y-auto whitespace-pre min-h-0 ${textWrapClass} ${fontSizeClass}`}
style={{
maxHeight: panelState === "max" ? "1500px" : "500px",
}}
>
{filteredLogs}
</pre>
</div>
)}
</div> </div>
); );
}; };
export default LogViewer; export default LogViewer;
+147 -52
View File
@@ -1,19 +1,50 @@
import { useState, useEffect, useCallback } from "react"; import { useState, useCallback, useMemo } from "react";
import { useAPI } from "../contexts/APIProvider"; import { useAPI } from "../contexts/APIProvider";
import { LogPanel } from "./LogViewer"; import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
export default function ModelsPage() { export default function ModelsPage() {
const { models, enableModelUpdates, unloadAllModels, loadModel, upstreamLogs, enableUpstreamLogs } = useAPI(); const { isNarrow } = useTheme();
const [isUnloading, setIsUnloading] = useState(false); const direction = isNarrow ? "vertical" : "horizontal";
const { upstreamLogs } = useAPI();
useEffect(() => { return (
enableModelUpdates(true); <PanelGroup direction={direction} className="gap-2" autoSaveId={"models-panel-group"}>
enableUpstreamLogs(true); <Panel id="models" defaultSize={50} minSize={isNarrow ? 0 : 25} maxSize={100} collapsible={isNarrow}>
return () => { <ModelsPanel />
enableModelUpdates(false); </Panel>
enableUpstreamLogs(false);
}; <PanelResizeHandle
}, []); className={
direction === "horizontal"
? "w-2 h-full bg-primary hover:bg-success transition-colors rounded"
: "w-full h-2 bg-primary hover:bg-success transition-colors rounded"
}
/>
<Panel collapsible={true} defaultSize={50} minSize={0}>
<div className="flex flex-col h-full space-y-4">
{direction === "horizontal" && <StatsPanel />}
<div className="flex-1 min-h-0">
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} />
</div>
</div>
</Panel>
</PanelGroup>
);
}
function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI();
const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
const filteredModels = useMemo(() => {
return models.filter((model) => showUnlisted || !model.unlisted);
}, [models, showUnlisted]);
const handleUnloadAllModels = useCallback(async () => { const handleUnloadAllModels = useCallback(async () => {
setIsUnloading(true); setIsUnloading(true);
@@ -22,57 +53,121 @@ export default function ModelsPage() {
} catch (e) { } catch (e) {
console.error(e); console.error(e);
} finally { } finally {
// at least give it a second to show the unloading message
setTimeout(() => { setTimeout(() => {
setIsUnloading(false); setIsUnloading(false);
}, 1000); }, 1000);
} }
}, []); }, [unloadAllModels]);
const toggleIdorName = useCallback(() => {
setShowIdorName((prev) => (prev === "name" ? "id" : "name"));
}, [showIdorName]);
return ( return (
<div className="h-screen"> <div className="card h-full flex flex-col">
<div className="flex flex-col md:flex-row gap-4"> <div className="shrink-0">
{/* Left Column */} <h2>Models</h2>
<div className="w-full md:w-1/2 flex items-top"> <div className="flex justify-between">
<div className="card w-full"> <div className="flex gap-2">
<h2 className="">Models</h2> <button className="btn flex items-center gap-2" onClick={toggleIdorName} style={{ lineHeight: "1.2" }}>
<button className="btn" onClick={handleUnloadAllModels} disabled={isUnloading}> <RiSwapBoxFill /> {showIdorName === "id" ? "ID" : "Name"}
{isUnloading ? "Unloading..." : "Unload All Models"}
</button> </button>
<table className="w-full mt-4">
<thead>
<tr className="border-b border-primary">
<th className="text-left p-2">Name</th>
<th className="text-left p-2"></th>
<th className="text-left p-2">State</th>
</tr>
</thead>
<tbody>
{models.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
<td className="p-2">
<a href={`/upstream/${model.id}/`} className="underline" target="_blank">
{model.id}
</a>
</td>
<td className="p-2">
<button className="btn btn--sm" disabled={model.state !== "stopped"} onClick={() => loadModel(model.id)}>Load</button>
</td>
<td className="p-2">
<span className={`status status--${model.state}`}>{model.state}</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
{/* Right Column */} <button
<div className="w-full md:w-1/2 flex items-top"> className="btn flex items-center gap-2"
<LogPanel id="modelsupstream" title="Upstream Logs" logData={upstreamLogs} className="h-full" /> onClick={() => setShowUnlisted(!showUnlisted)}
style={{ lineHeight: "1.2" }}
>
{showUnlisted ? <RiEyeFill /> : <RiEyeOffFill />} unlisted
</button>
</div>
<button className="btn flex items-center gap-2" onClick={handleUnloadAllModels} disabled={isUnloading}>
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
</button>
</div> </div>
</div> </div>
<div className="flex-1 overflow-y-auto">
<table className="w-full">
<thead className="sticky top-0 bg-card z-10">
<tr className="border-b border-primary bg-surface">
<th className="text-left p-2">{showIdorName === "id" ? "Model ID" : "Name"}</th>
<th className="text-left p-2"></th>
<th className="text-left p-2">State</th>
</tr>
</thead>
<tbody>
{filteredModels.map((model) => (
<tr key={model.id} className="border-b hover:bg-secondary-hover border-border">
<td className={`p-2 ${model.unlisted ? "text-txtsecondary" : ""}`}>
<a href={`/upstream/${model.id}/`} className={`underline`} target="_blank">
{showIdorName === "id" ? model.id : model.name !== "" ? model.name : model.id}
</a>
{model.description !== "" && (
<p className={model.unlisted ? "text-opacity-70" : ""}>
<em>{model.description}</em>
</p>
)}
</td>
<td className="p-2 w-[50px]">
<button
className="btn btn--sm"
disabled={model.state !== "stopped"}
onClick={() => loadModel(model.id)}
>
Load
</button>
</td>
<td className="p-2 w-[75px]">
<span className={`status status--${model.state}`}>{model.state}</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
function StatsPanel() {
const { metrics } = useAPI();
const [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond] = useMemo(() => {
const totalRequests = metrics.length;
if (totalRequests === 0) {
return [0, 0, 0];
}
const totalInputTokens = metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const avgTokensPerSecond = (metrics.reduce((sum, m) => sum + m.tokens_per_second, 0) / totalRequests).toFixed(2);
return [totalRequests, totalInputTokens, totalOutputTokens, avgTokensPerSecond];
}, [metrics]);
return (
<div className="card">
<div className="rounded-lg overflow-hidden border border-gray-200">
<table className="w-full">
<tbody>
<tr>
<th className="p-2 font-medium border-b border-gray-200 text-right">Requests</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Processed</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Generated</th>
<th className="p-2 font-medium border-l border-b border-gray-200 text-right">Tokens/Sec</th>
</tr>
<tr>
<td className="p-2 text-right border-r border-gray-200">{totalRequests}</td>
<td className="p-2 text-right border-r border-gray-200">
{new Intl.NumberFormat().format(totalInputTokens)}
</td>
<td className="p-2 text-right border-r border-gray-200">
{new Intl.NumberFormat().format(totalOutputTokens)}
</td>
<td className="p-2 text-right">{avgTokensPerSecond}</td>
</tr>
</tbody>
</table>
</div>
</div> </div>
); );
} }