Compare commits

...

12 Commits

Author SHA1 Message Date
Benson Wong 0b31ccacc1 ui-svelte: fix histogram calculation (#695)
- Fix the histogram calculation to use server provided generation
tokens/second.
- Move histogram to Activities page where it can exist with the rest of
the token metrics

Fixes #681
2026-04-22 23:42:39 -07:00
Bryan Gahagan 5938dbee8f Push unified docker images on scheduled runs (#694)
Fixes #693
2026-04-22 20:46:51 -07:00
Benson Wong 66639e83f7 proxy: replace fsnotify with stat-poll watcher and add SIGHUP reload (#685)
The fsnotify-based config watcher does not work reliably when the config
file is bind-mounted into a Docker container as an individual file, and
mishandles k8s ConfigMap projections (atomically swapped symlinks).
Replace it with a small os.Stat-polling watcher and add SIGHUP as an
explicit reload signal.

- new proxy/configwatcher package: 2s os.Stat poller, follows symlinks,
  fires on mtime/size change and on missing -> present transitions
- SIGHUP triggers reload unconditionally (works without --watch-config)
  via the same ConfigFileChangedEvent pipeline so the UI sees identical
  state transitions
- watcher goroutine now exits cleanly on shutdown via a context
- drop github.com/fsnotify/fsnotify dependency

fixes #682
2026-04-21 23:21:48 -07:00
Benson Wong 625b296720 docker/unified: add uv via pip install (#681)
Install uv after the cpp tool binaries are copied and before the
llama-swap binary, enabling `uv run` usage for Python-based inference
backends like vLLM.

- add python3-pip to runtime apt installs
- add `pip install uv --break-system-packages` after cpp installs

fixes #628

Co-authored-by: Claude <noreply@anthropic.com>
2026-04-20 20:55:51 -07:00
Benson Wong 231e62291c proxy: fix matrix race and process stop bug (#677)
- matrix.go change logic to consider any proxy.Process not in
StateStopped or StateShutdown
- process.StopImmediately, and Stop() which called it had a subtle bug
where it only handled state transitions from StateReady to
StateStopping. StateStarting -> StateStopping was ignored completely.

fix: #670
2026-04-20 00:21:11 -07:00
Benson Wong 57ac666598 .github/workflows: tweak push ghcr conditional (#676) 2026-04-19 13:56:26 -07:00
Benson Wong 69728301f5 .github/workflows: add toggle for pushing unified images to github (#672)
Add ability to dispatch (manually run) unified container builds in github without push to ghcr.io.
2026-04-19 10:10:48 -07:00
Benson Wong c176fa70f1 docker/unified: add spirv-headers to fix vulkan build (#669) 2026-04-18 12:18:10 -07:00
Benson Wong 5e3c646829 proxy: compress captures with zstd (#668)
The previous captures were saved uncompressed in memory. In agentic
workflows there can be many turns with each request containing the
previous context in the body with a lot of redundant data. Use zstd to
compress the request and response data before keeping a copy of memory.

Results: 

- Average Percentage Saved: 73.19%
- Average Compression Factor: ~6.77:1
2026-04-17 23:29:37 -07:00
Benson Wong c3f0d43e6e proxy: fix race conditions during swap (#667)
I pointed Opus 4.7 (high effort) at proxy.ProcessGroup to identify any
race conditions in the swapping code. It found a race condition where
there is a small window in the fast path for routing a request to a
loaded model. There is a very small window where:

- model M1 is loaded and ready for requests
- a request, R1, for M1 comes in 
- a request, R2, for M2 comes in almost immediately after
- R1 acquires the lock, sees M1 is loaded (fast path), releases the lock
`[race window]` and the request is ready to be forwarded
- the race window occurs between the release of the lock and the request
being forwarded
  - the lock is released so requests can be handled concurrently 
- R2 comes in within the `[race window]`, acquires the lock, triggers a
model swap to M2. stopping M1
- R1 is forwarded to a model that is unloaded or in the process of
shutting down creating an error response

In deployed systems the race window is very small and doesn't happen
often. However with #635 and PR #656 I though this deserved a bit more
attention. It is not concluded that this race is the cause of #635 but
the race is likely to happen more often under sustained or high load.

AI Note: Opus 4.7 x-high effort took about an hour to write the original
patch. With the pattern discovered the fix to matrix.go was very quick.
GLM 5.1 using the previous established patterns was able to easily write
the fix for ProcessGroup.StopProcesses().

Supersedes: #656
Updates: #277, #635
2026-04-17 21:23:17 -07:00
Benson Wong f6cf9f5844 proxy: Refactor tests (#660)
- use YAML for test configurations
- remove most uses of simple-responder, opting to use
process.testHandler

Fixes #655
2026-04-16 22:47:42 -07:00
Benson Wong 121fd93ad8 Makefile: restore linux arm64 targets
Fix #641
2026-04-14 22:05:39 -07:00
29 changed files with 2098 additions and 710 deletions
+2 -11
View File
@@ -19,9 +19,6 @@ jobs:
run-tests: run-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
working-directory: ui-svelte
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -32,11 +29,5 @@ jobs:
cache: 'npm' cache: 'npm'
cache-dependency-path: ui-svelte/package-lock.json cache-dependency-path: ui-svelte/package-lock.json
- name: Install dependencies - name: Run UI tests
run: npm ci run: make test-ui
- name: Type check
run: npm run check
- name: Run tests
run: npm test
+6 -1
View File
@@ -36,6 +36,11 @@ on:
type: boolean type: boolean
required: false required: false
default: true default: true
push_to_ghcr:
description: "Push images to ghcr.io"
type: boolean
required: false
default: true
permissions: permissions:
contents: read contents: read
@@ -116,7 +121,7 @@ jobs:
docker/unified/build-image.sh --${{ matrix.backend }} docker/unified/build-image.sh --${{ matrix.backend }}
- name: Push to GitHub Container Registry - name: Push to GitHub Container Registry
if: ${{ !env.ACT }} if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
run: | run: |
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}" BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
DATE_TAG=$(date -u +%Y-%m-%d) DATE_TAG=$(date -u +%Y-%m-%d)
+1
View File
@@ -24,6 +24,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`. - Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
- Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory - Use `make test-dev` after running new tests for a quick over all test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory
- Use `make test-all` before completing work. This includes long running concurrency tests. - Use `make test-all` before completing work. This includes long running concurrency tests.
- Use `make test-ui` after making changes to the UI in ui-svelte/
### Commit message example format: ### Commit message example format:
+13 -4
View File
@@ -48,10 +48,15 @@ mac: ui
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64 GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
# Build Linux binary # Build Linux binary
linux: ui linux: linux-arm64 linux-amd64
@echo "Building Linux binary..."
linux-amd64: ui
@echo "Building Linux AMD64 binary..."
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
linux-arm64: ui
@echo "Building Linux ARM64 binary..."
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
# Build Windows binary # Build Windows binary
windows: ui windows: ui
@@ -92,5 +97,9 @@ wol-proxy: $(BUILD_DIR)
@echo "Building wol-proxy" @echo "Building wol-proxy"
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
test-ui:
cd ui-svelte && npm ci && npm run check && npm test
# Phony targets # Phony targets
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy .PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
.PHONE: linux linux-arm64 linux-amd64
+183
View File
@@ -0,0 +1,183 @@
# Improve Testability (#655)
## Current Pain Points
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
## Stages
### Stage 1: YAML-based test config helper
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
Create a test helper in `proxy/helpers_test.go`:
```go
// testConfigFromYAML substitutes simple-responder paths and loads through
// the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
```
Tests would then look like:
```go
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
model2:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
`)
proxy := New(config)
// ... same assertions
}
```
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
**2a. Add testHandler to Process:**
```go
// In Process struct (process.go):
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
```
In `Process.Start()`, skip subprocess + health check when handler is set:
```go
func (p *Process) start() error {
if p.testHandler != nil {
p.setState(StateReady)
return nil
}
// existing subprocess logic...
}
```
In `Process.ProxyRequest()`, delegate directly to the handler:
```go
// Before the reverseProxy.ServeHTTP call:
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
return
}
```
**2b. Test helper to create the handler:**
```go
// newTestHandler returns an http.Handler that mimics llama.cpp's API
// (same endpoints as simple-responder).
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
// ... other endpoints
return mux
}
```
Tests for routing/auth/CORS/streaming then become:
```go
func TestProxyManager_AuthRequired(t *testing.T) {
handler := newTestHandler("model1")
config := testConfigFromYAML(t, `
healthCheckTimeout: 15
logLevel: error
requiredAPIKeys: [test-key]
models:
model1:
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
`)
pm := NewProxyManager(config)
// inject handler — skips subprocess, health check, port allocation
pm.processGroups["model1"].process.testHandler = handler
}
```
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
### Stage 3: Migrate tests incrementally
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
Priority order:
1. `proxymanager_test.go` routing tests (highest count, most repetition)
2. `peerproxy_test.go` (straightforward, all HTTP routing)
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
Tests that **must keep simple-responder:**
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
**Scope:** ~60-70% of tests can drop simple-responder.
### Stage 4 (optional): Process interface for ProcessGroup
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
```go
type ProcessController interface {
Start() error
Stop(StopStrategy)
ProxyRequest(http.ResponseWriter, *http.Request) error
CurrentState() ProcessState
ID() string
SetState(ProcessState) // for test setup
}
```
This requires:
- Extracting the interface
- A `MockProcess` implementation
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
## Effort/Impact Summary
| Stage | Effort | Impact | Risk |
|-------|--------|--------|------|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
| 4. Process interface | High | Pure unit tests possible | Medium |
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
+5 -1
View File
@@ -42,6 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git python3 python3-pip libssl-dev \ build-essential cmake git python3 python3-pip libssl-dev \
curl ca-certificates ccache make wget software-properties-common \ curl ca-certificates ccache make wget software-properties-common \
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \ libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
spirv-headers \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
WORKDIR /build WORKDIR /build
@@ -148,7 +149,7 @@ ARG IK_LLAMA_COMMIT_HASH=unknown
ARG RUN_UID=0 ARG RUN_UID=0
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-numpy python3-sentencepiece \ python3-numpy python3-sentencepiece python3-pip \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Create non-root user when RUN_UID != 0 # Create non-root user when RUN_UID != 0
@@ -179,6 +180,9 @@ COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
# Copy ik-llama-server (CUDA only; empty copy for vulkan) # Copy ik-llama-server (CUDA only; empty copy for vulkan)
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/ COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
# Install uv
RUN pip install uv --break-system-packages
# Copy llama-swap binary # Copy llama-swap binary
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/ COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
COPY --from=llama-swap-download /install/llama-swap-version /tmp/ COPY --from=llama-swap-download /install/llama-swap-version /tmp/
+1 -1
View File
@@ -4,8 +4,8 @@ go 1.26.1
require ( require (
github.com/billziss-gh/golib v0.2.0 github.com/billziss-gh/golib v0.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/klauspost/compress v1.18.5
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
+2 -2
View File
@@ -11,8 +11,6 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -34,6 +32,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+46 -40
View File
@@ -9,14 +9,15 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime"
"syscall" "syscall"
"time" "time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/mostlygeek/llama-swap/proxy/configwatcher"
) )
var ( var (
@@ -79,6 +80,17 @@ func main() {
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Reload signals (SIGHUP on POSIX, none on Windows — Windows does not
// deliver SIGHUP). Always wired up so `kill -HUP` works regardless of
// --watch-config.
reloadChan := make(chan os.Signal, 1)
if runtime.GOOS != "windows" {
signal.Notify(reloadChan, syscall.SIGHUP)
}
// Context that bounds the lifetime of background watcher goroutines.
watcherCtx, watcherCancel := context.WithCancel(context.Background())
// Create server with initial handler // Create server with initial handler
srv := &http.Server{ srv := &http.Server{
Addr: *listenStr, Addr: *listenStr,
@@ -121,52 +133,45 @@ func main() {
// load the initial proxy manager // load the initial proxy manager
reloadProxyManager() reloadProxyManager()
debouncedReload := debounce(time.Second, reloadProxyManager) debouncedReload := debounce(time.Second, reloadProxyManager)
if *watchConfig {
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
fmt.Println("Watching Configuration for changes") // Listen for ConfigFileChangedEvent unconditionally so SIGHUP and the
// poll-based watcher both feed the same debounced reload pipeline. The
// UI also listens for the matching ReloadingStateEnd emitted from
// reloadProxyManager.
defer event.On(func(e proxy.ConfigFileChangedEvent) {
if e.ReloadingState == proxy.ReloadingStateStart {
debouncedReload()
}
})()
// SIGHUP (or platform-equivalent) → reload. Back-to-back signals collapse
// to one reload via the debounce window, which is the desired behavior.
go func() {
for range reloadChan {
fmt.Println("Received reload signal, reloading configuration")
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
}()
if *watchConfig {
go func() { go func() {
absConfigPath, err := filepath.Abs(*configPath) absConfigPath, err := filepath.Abs(*configPath)
if err != nil { if err != nil {
fmt.Printf("Error getting absolute path for watching config file: %v\n", err) fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
return return
} }
watcher, err := fsnotify.NewWatcher() fmt.Println("Watching configuration for changes (poll-based, 2s interval)")
if err != nil { (&configwatcher.Watcher{
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err) Path: absConfigPath,
return Interval: configwatcher.DefaultInterval,
} OnChange: func() {
event.Emit(proxy.ConfigFileChangedEvent{
configDir := filepath.Dir(absConfigPath) ReloadingState: proxy.ReloadingStateStart,
err = watcher.Add(configDir) })
if err != nil { },
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err) }).Run(watcherCtx)
return
}
defer watcher.Close()
for {
select {
case changeEvent := <-watcher.Events:
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
// the change for k8s configmap
event.Emit(proxy.ConfigFileChangedEvent{
ReloadingState: proxy.ReloadingStateStart,
})
}
case err := <-watcher.Errors:
log.Printf("File watcher error: %v", err)
}
}
}() }()
} }
@@ -174,6 +179,7 @@ func main() {
go func() { go func() {
sig := <-sigChan sig := <-sigChan
fmt.Printf("Received signal %v, shutting down...\n", sig) fmt.Printf("Received signal %v, shutting down...\n", sig)
watcherCancel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
+85
View File
@@ -0,0 +1,85 @@
// Package configwatcher provides a simple cross-platform file watcher based
// on os.Stat polling. It works correctly inside Docker containers where the
// config file is bind-mounted as an individual file, and for k8s ConfigMap
// projections (which present the file as a symlink to an atomically swapped
// target) — both cases where inotify-based watchers are unreliable.
package configwatcher
import (
"context"
"errors"
"io/fs"
"log"
"os"
"time"
)
const DefaultInterval = 2 * time.Second
type Watcher struct {
Path string
Interval time.Duration
OnChange func()
}
type snapshot struct {
exists bool
modTime time.Time
size int64
}
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
// OnChange whenever the file's modification time or size changes, or when
// the file reappears after being missing. The baseline poll establishes
// initial state and does not fire OnChange.
func (w *Watcher) Run(ctx context.Context) {
interval := w.Interval
if interval <= 0 {
interval = DefaultInterval
}
prev := stat(w.Path)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
cur := stat(w.Path)
if changed(prev, cur) && w.OnChange != nil {
w.OnChange()
}
prev = cur
}
}
}
func stat(path string) snapshot {
fi, err := os.Stat(path)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
log.Printf("configwatcher: stat %s: %v", path, err)
}
return snapshot{}
}
return snapshot{
exists: true,
modTime: fi.ModTime(),
size: fi.Size(),
}
}
func changed(prev, cur snapshot) bool {
// Present → missing: stay quiet (likely a transient rename-style write).
// Missing → present: fire so we reload as soon as the file comes back.
if !cur.exists {
return false
}
if !prev.exists {
return true
}
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
}
+191
View File
@@ -0,0 +1,191 @@
package configwatcher
import (
"context"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testInterval = 25 * time.Millisecond
// startWatcher launches w.Run in a goroutine and returns a function that
// cancels the context and waits for Run to return.
func startWatcher(t *testing.T, w *Watcher) func() {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
return func() {
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("watcher did not stop within 2s of cancel")
}
}
}
// waitForCount blocks until counter reaches want or timeout elapses.
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if atomic.LoadInt64(counter) >= want {
return true
}
time.Sleep(5 * time.Millisecond)
}
return false
}
func TestWatcher_NoFireOnBaseline(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 5)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
}
func TestWatcher_DetectsModTimeChange(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
// Force a known baseline mtime.
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
require.NoError(t, os.Chtimes(path, base, base))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
// Let the baseline settle.
time.Sleep(testInterval * 2)
// Bump mtime well above the baseline so low-resolution filesystems still notice.
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
}
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
fi, err := os.Stat(path)
require.NoError(t, err)
originalMtime := fi.ModTime()
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
// Reset mtime back to the original so size is the only signal.
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
}
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
dir := t.TempDir()
targetA := filepath.Join(dir, "targetA")
targetB := filepath.Join(dir, "targetB")
link := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
if err := os.Symlink(targetA, link); err != nil {
if runtime.GOOS == "windows" {
t.Skipf("symlink creation requires privilege on Windows: %v", err)
}
t.Fatalf("os.Symlink: %v", err)
}
var n int64
stop := startWatcher(t, &Watcher{
Path: link,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
// temp name, then rename over the existing one.
tmpLink := filepath.Join(dir, "config.yaml.tmp")
require.NoError(t, os.Symlink(targetB, tmpLink))
require.NoError(t, os.Rename(tmpLink, link))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
}
func TestWatcher_FileMissingThenReturns(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
var n int64
stop := startWatcher(t, &Watcher{
Path: path,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.Remove(path))
time.Sleep(testInterval * 3)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
}
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
w := &Watcher{Path: path, Interval: testInterval}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() { w.Run(ctx); close(done) }()
time.Sleep(testInterval * 2)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run did not return within 2s of cancel")
}
}
+202
View File
@@ -1,15 +1,22 @@
package proxy package proxy
import ( import (
"encoding/json"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"sync" "sync"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -66,6 +73,16 @@ func getTestPort() int {
return port return port
} }
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
t.Helper()
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
require.NoError(t, err)
return cfg
}
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig { func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
} }
@@ -88,3 +105,188 @@ proxy: "http://127.0.0.1:%d"
return cfg return cfg
} }
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
// model IDs to their respond strings; if a model ID is not in the map, the model
// ID itself is used.
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
for _, pg := range pm.processGroups {
for modelID, process := range pg.processes {
respond := modelID
if r, ok := modelResponses[modelID]; ok {
respond = r
}
process.testHandler = newTestHandler(respond)
}
}
}
// newTestHandler returns an http.Handler that mimics simple-responder's API.
// It supports the endpoints that routing tests depend on, without launching
// any subprocess or binding any port.
func newTestHandler(respond string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
isStreaming := r.URL.Query().Get("stream") == "true"
if wait := r.URL.Query().Get("wait"); wait != "" {
if d, err := time.ParseDuration(wait); err == nil {
time.Sleep(d)
}
}
if isStreaming {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher := w.(http.Flusher)
for i := 0; i < 10; i++ {
data, _ := json.Marshal(map[string]any{
"created": time.Now().Unix(),
"choices": []map[string]any{
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
flusher.Flush()
}
finalData, _ := json.Marshal(map[string]any{
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
flusher.Flush()
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
flusher.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"h_content_length": r.Header.Get("Content-Length"),
"request_body": string(bodyBytes),
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
"timings": map[string]any{
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
"predicted_ms": 17, "predicted_per_second": 10,
},
})
}
})
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
if modelName != respond {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
return
}
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
})
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"responseMessage": respond,
"usage": map[string]any{
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
},
})
})
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
return
}
model := r.FormValue("model")
if model == "" {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
return
}
file, _, err := r.FormFile("file")
if err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
return
}
fileBytes, _ := io.ReadAll(file)
file.Close()
json.NewEncoder(w).Encode(map[string]any{
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
"model": model,
"h_content_type": r.Header.Get("Content-Type"),
"h_content_length": r.Header.Get("Content-Length"),
})
})
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
model := r.URL.Query().Get("model")
json.NewEncoder(w).Encode(map[string]any{
"voices": []string{"voice1"}, "model": model,
})
})
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprint(w, respond)
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
})
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
modelName := gjson.GetBytes(body, "model").String()
json.NewEncoder(w).Encode(map[string]any{
"model": modelName, "images": []string{},
})
})
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"loras": []string{},
})
})
return mux
}
+33 -2
View File
@@ -147,6 +147,20 @@ type Matrix struct {
config config.Config config config.Config
proxyLogger *LogMonitor proxyLogger *LogMonitor
upstreamLogger *LogMonitor upstreamLogger *LogMonitor
// inflight tracks ProxyRequest calls that have released m.Lock but may
// not yet have incremented Process.inFlightRequests. A concurrent
// request that needs to evict models waits for inflight to drain under
// m.Lock before stopping anything. Without this, a request that
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
// races with Stop()'s Wait() and can be killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook invoked in the no-eviction path
// after m.Lock is released but before the request is dispatched to
// Process.ProxyRequest. Tests use it to park a request at the exact
// race window to deterministically reproduce the race.
testDelayFastPath func()
} }
// NewMatrix creates a Matrix from config. It creates a Process for every // NewMatrix creates a Matrix from config. It creates a Process for every
@@ -197,6 +211,13 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req
// Evict models that need to be stopped // Evict models that need to be stopped
if len(result.Evict) > 0 { if len(result.Evict) > 0 {
// Wait for any in-flight ProxyRequest calls to register on their
// Process before stopping anything. Without this, a request that
// released m.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
m.inflight.Wait()
var wg sync.WaitGroup var wg sync.WaitGroup
for _, evictModel := range result.Evict { for _, evictModel := range result.Evict {
if p, exists := m.processes[evictModel]; exists { if p, exists := m.processes[evictModel]; exists {
@@ -209,8 +230,18 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req
} }
wg.Wait() wg.Wait()
} }
// Register this request in inflight before releasing m.Lock so a
// concurrent eviction will wait for it to complete.
m.inflight.Add(1)
defer m.inflight.Done()
isFastPath := len(result.Evict) == 0
m.Unlock() m.Unlock()
if isFastPath && m.testDelayFastPath != nil {
m.testDelayFastPath()
}
// Proxy the request (Process handles on-demand start) // Proxy the request (Process handles on-demand start)
process.ProxyRequest(w, r) process.ProxyRequest(w, r)
return nil return nil
@@ -266,7 +297,7 @@ func (m *Matrix) Shutdown() {
wg.Wait() wg.Wait()
} }
// RunningModels returns model names currently in StateReady. // RunningModels returns model names currently in an active (non-stopped) state.
func (m *Matrix) RunningModels() []string { func (m *Matrix) RunningModels() []string {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@@ -277,7 +308,7 @@ func (m *Matrix) RunningModels() []string {
func (m *Matrix) runningModels() []string { func (m *Matrix) runningModels() []string {
var running []string var running []string
for id, process := range m.processes { for id, process := range m.processes {
if process.CurrentState() == StateReady { if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
running = append(running, id) running = append(running, id)
} }
} }
+122
View File
@@ -1,7 +1,11 @@
package proxy package proxy
import ( import (
"net/http"
"net/http/httptest"
"runtime"
"testing" "testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -169,6 +173,124 @@ func TestMatrixSolver_NothingRunning(t *testing.T) {
assert.Equal(t, []string{"g", "v"}, result.TargetSet) assert.Equal(t, []string{"g", "v"}, result.TargetSet)
} }
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
// cannot stop a process while an in-flight ProxyRequest for that process is
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
// matrix-level inflight tracking, the eviction's Stop() races with the
// pending request and kills it mid-start.
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
ExpandedSets: []config.ExpandedSet{
{SetName: "s1", Models: []string{"model1"}},
{SetName: "s2", Models: []string{"model2"}},
},
Matrix: &config.MatrixConfig{},
}
m := NewMatrix(cfg, testLogger, testLogger)
defer m.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
m.processes["model1"].testHandler = newTestHandler("model1")
m.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 so it reaches StateReady and
// subsequent requests take the no-eviction path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
// Install fast-path hook that signals arrival and waits for release.
// This parks R2 at the race window — after m.Lock is released but
// before Process.inFlightRequests.Add(1).
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
m.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: no-eviction request for model1. Will pause at the hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: request for model2 which requires evicting model1. Must wait for
// R2 to finish before touching model1.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, m.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired m.Lock and entered the eviction path. In
// the fixed code, R3 then blocks on m.inflight.Wait() while still
// holding the lock, so TryLock keeps failing.
for m.TryLock() {
m.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code R3 is blocked and nothing changes; in the
// buggy code R3 will Stop() model1 and start model2 within microseconds.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if m.processes["model1"].CurrentState() != StateReady ||
m.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
default:
}
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
"model1 must stay Ready while an in-flight request is pending")
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is evicted")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
func TestMatrixSolver_FullScenario(t *testing.T) { func TestMatrixSolver_FullScenario(t *testing.T) {
// Simulates the example config: // Simulates the example config:
// standard: [g,v], [q,v], [m,v] // standard: [g,v], [q,v], [m,v]
+101 -34
View File
@@ -13,10 +13,54 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// zstdEncOptions are the shared zstd encoder options for maximum compression.
var zstdEncOptions = []zstd.EOption{
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
}
// zstdDecOptions are the shared zstd decoder options.
var zstdDecOptions = []zstd.DOption{}
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
var zstdEncPool = &sync.Pool{
New: func() interface{} {
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
return enc
},
}
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
var zstdDecPool = &sync.Pool{
New: func() interface{} {
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
return dec
},
}
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
// Returns compressed bytes and the original JSON byte count for logging.
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
jsonBytes, err := json.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("marshal capture: %w", err)
}
enc := zstdEncPool.Get().(*zstd.Encoder)
defer zstdEncPool.Put(enc)
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
}
// decompressCapture decompresses zstd-compressed JSON and returns it.
func decompressCapture(data []byte) ([]byte, error) {
dec := zstdDecPool.Get().(*zstd.Decoder)
defer zstdDecPool.Put(dec)
return dec.DecodeAll(data, nil)
}
// TokenMetrics represents parsed token statistics from llama-server logs // TokenMetrics represents parsed token statistics from llama-server logs
type TokenMetrics struct { type TokenMetrics struct {
ID int `json:"id"` ID int `json:"id"`
@@ -40,18 +84,6 @@ type ReqRespCapture struct {
RespBody []byte `json:"resp_body"` RespBody []byte `json:"resp_body"`
} }
// Size returns the approximate memory usage of this capture in bytes
func (c *ReqRespCapture) Size() int {
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
for k, v := range c.ReqHeaders {
size += len(k) + len(v)
}
for k, v := range c.RespHeaders {
size += len(k) + len(v)
}
return size
}
// TokenMetricsEvent represents a token metrics event // TokenMetricsEvent represents a token metrics event
type TokenMetricsEvent struct { type TokenMetricsEvent struct {
Metrics TokenMetrics Metrics TokenMetrics
@@ -71,10 +103,10 @@ type metricsMonitor struct {
// capture fields // capture fields
enableCaptures bool enableCaptures bool
captures map[int]ReqRespCapture // map for O(1) lookup by ID captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
captureOrder []int // track insertion order for FIFO eviction captureOrder []int // track insertion order for FIFO eviction
captureSize int // current total size in bytes captureSize int // current total compressed size in bytes
maxCaptureSize int // max bytes for captures maxCaptureSize int // max bytes for captures (uncompressed)
} }
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the // newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
@@ -84,7 +116,7 @@ func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int)
logger: logger, logger: logger,
maxMetrics: maxMetrics, maxMetrics: maxMetrics,
enableCaptures: captureBufferMB > 0, enableCaptures: captureBufferMB > 0,
captures: make(map[int]ReqRespCapture), captures: make(map[int][]byte),
captureOrder: make([]int, 0), captureOrder: make([]int, 0),
captureSize: 0, captureSize: 0,
maxCaptureSize: captureBufferMB * 1024 * 1024, maxCaptureSize: captureBufferMB * 1024 * 1024,
@@ -108,45 +140,80 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
} }
// addCapture adds a new capture to the buffer with size-based eviction. // addCapture adds a new capture to the buffer with size-based eviction.
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize. // Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) { func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
if !mp.enableCaptures { if !mp.enableCaptures {
return return
} }
mp.mu.Lock() compressed, uncompressedBytes, err := compressCapture(&capture)
defer mp.mu.Unlock() if err != nil {
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
captureSize := capture.Size()
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return return
} }
// Evict oldest (FIFO) until room available captureSize := len(compressed)
if captureSize > mp.maxCaptureSize {
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
return
}
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
mp.mu.Lock()
defer mp.mu.Unlock()
// Evict oldest (FIFO) until room available for the compressed data
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 { for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
oldestID := mp.captureOrder[0] oldestID := mp.captureOrder[0]
mp.captureOrder = mp.captureOrder[1:] mp.captureOrder = mp.captureOrder[1:]
if evicted, exists := mp.captures[oldestID]; exists { if evicted, exists := mp.captures[oldestID]; exists {
mp.captureSize -= evicted.Size() l := len(evicted)
mp.captureSize -= l
delete(mp.captures, oldestID) delete(mp.captures, oldestID)
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
} }
} }
mp.captures[capture.ID] = capture mp.captures[capture.ID] = compressed
mp.captureOrder = append(mp.captureOrder, capture.ID) mp.captureOrder = append(mp.captureOrder, capture.ID)
mp.captureSize += captureSize mp.captureSize += captureSize
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
} }
// getCaptureByID returns a capture by its ID, or nil if not found. // getCompressedBytes returns the raw compressed bytes for a capture by ID.
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture { func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
mp.mu.RLock() mp.mu.RLock()
defer mp.mu.RUnlock() defer mp.mu.RUnlock()
if capture, exists := mp.captures[id]; exists { data, exists := mp.captures[id]
return &capture return data, exists
}
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
// If decompress=false, returns the raw zstd-compressed bytes.
// Returns nil if the capture is not found.
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
mp.mu.RLock()
defer mp.mu.RUnlock()
data, exists := mp.captures[id]
if !exists {
return nil
} }
return nil
if !decompress {
return data
}
decompressed, err := decompressCapture(data)
if err != nil {
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
return nil
}
return decompressed
} }
// getMetrics returns a copy of the current metrics // getMetrics returns a copy of the current metrics
@@ -290,8 +357,8 @@ func (mp *metricsMonitor) wrapHandler(
RespHeaders: respHeaders, RespHeaders: respHeaders,
RespBody: body, RespBody: body,
} }
// Only set HasCapture if the capture will actually be stored (not too large) compressed, _, err := compressCapture(capture)
if capture.Size() <= mp.maxCaptureSize { if err == nil && len(compressed) <= mp.maxCaptureSize {
tm.HasCapture = true tm.HasCapture = true
} }
} }
+83 -42
View File
@@ -5,6 +5,7 @@ import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"encoding/json" "encoding/json"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync" "sync"
@@ -953,28 +954,27 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
}) })
} }
func TestReqRespCapture_Size(t *testing.T) { func TestReqRespCapture_CompressedSize(t *testing.T) {
t.Run("calculates size correctly", func(t *testing.T) { t.Run("compressed size is smaller than uncompressed", func(t *testing.T) {
capture := ReqRespCapture{ capture := ReqRespCapture{
ID: 1, ID: 1,
ReqPath: "/v1/chat/completions", // 20 bytes ReqPath: "/v1/chat/completions",
ReqHeaders: map[string]string{ ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`),
"Content-Type": "application/json", // 12 + 16 = 28 RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`),
},
ReqBody: []byte("request body"), // 12 bytes
RespHeaders: map[string]string{
"X-Test": "value", // 6 + 5 = 11
},
RespBody: []byte("response body"), // 13 bytes
} }
// Expected: 20 + 12 + 13 + 28 + 11 = 84 compressed, uncompressed, err := compressCapture(&capture)
assert.Equal(t, 84, capture.Size()) assert.NoError(t, err)
assert.Greater(t, uncompressed, 0)
assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed)
}) })
t.Run("handles empty capture", func(t *testing.T) { t.Run("empty capture produces compressed output", func(t *testing.T) {
capture := ReqRespCapture{} capture := ReqRespCapture{}
assert.Equal(t, 0, capture.Size()) compressed, _, err := compressCapture(&capture)
assert.NoError(t, err)
assert.NotNil(t, compressed)
assert.True(t, len(compressed) > 0)
}) })
} }
@@ -989,7 +989,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
mm.addCapture(capture) mm.addCapture(capture)
// Should not store capture // Should not store capture
assert.Nil(t, mm.getCaptureByID(0)) assert.Nil(t, mm.getCaptureByID(0, false))
}) })
t.Run("adds capture when enabled", func(t *testing.T) { t.Run("adds capture when enabled", func(t *testing.T) {
@@ -1002,41 +1002,55 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(0) retrieved := mm.getCaptureByID(0, true)
assert.NotNil(t, retrieved) assert.NotNil(t, retrieved)
assert.Equal(t, 0, retrieved.ID)
assert.Equal(t, []byte("test request"), retrieved.ReqBody) var decoded ReqRespCapture
assert.Equal(t, []byte("test response"), retrieved.RespBody) err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 0, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
}) })
t.Run("evicts oldest when exceeding max size", func(t *testing.T) { t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 // Set small limit for test // Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
mm.maxCaptureSize = 450
// Add captures that will exceed the limit // Use random-looking data that doesn't compress well with zstd
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)} rng := rand.New(rand.NewSource(42))
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)} capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)}
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)} rng.Read(capture1.ReqBody)
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)}
rng.Read(capture2.ReqBody)
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)}
rng.Read(capture3.ReqBody)
mm.addCapture(capture1) mm.addCapture(capture1)
mm.addCapture(capture2) mm.addCapture(capture2)
// Adding capture3 should evict capture1 // Adding capture3 should evict capture1
mm.addCapture(capture3) mm.addCapture(capture3)
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted") assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist") retrieved := mm.getCaptureByID(1, true)
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist") assert.NotNil(t, retrieved, "capture 1 should exist")
retrieved = mm.getCaptureByID(2, true)
assert.NotNil(t, retrieved, "capture 2 should exist")
}) })
t.Run("skips capture larger than max size", func(t *testing.T) { t.Run("skips capture larger than max size", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
mm.maxCaptureSize = 100 mm.maxCaptureSize = 100
// Add a capture larger than max // Use random data that doesn't compress well to create an oversized capture
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)} rng := rand.New(rand.NewSource(99))
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)}
rng.Read(largeCapture.ReqBody)
mm.addCapture(largeCapture) mm.addCapture(largeCapture)
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored") assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
}) })
} }
@@ -1044,21 +1058,44 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
t.Run("returns nil for non-existent ID", func(t *testing.T) { t.Run("returns nil for non-existent ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
assert.Nil(t, mm.getCaptureByID(999)) assert.Nil(t, mm.getCaptureByID(999, false))
}) })
t.Run("returns capture by ID", func(t *testing.T) { t.Run("returns decompressed capture by ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5) mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{ capture := ReqRespCapture{
ID: 42, ID: 42,
ReqBody: []byte("test"), ReqBody: []byte("test request"),
RespBody: []byte("test response"),
} }
mm.addCapture(capture) mm.addCapture(capture)
retrieved := mm.getCaptureByID(42) retrieved := mm.getCaptureByID(42, true)
assert.NotNil(t, retrieved) assert.NotNil(t, retrieved)
assert.Equal(t, 42, retrieved.ID)
var decoded ReqRespCapture
err := json.Unmarshal(retrieved, &decoded)
assert.NoError(t, err)
assert.Equal(t, 42, decoded.ID)
assert.Equal(t, []byte("test request"), decoded.ReqBody)
assert.Equal(t, []byte("test response"), decoded.RespBody)
})
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 5)
capture := ReqRespCapture{
ID: 42,
ReqBody: []byte("test request body"),
RespBody: []byte("test response body"),
}
mm.addCapture(capture)
compressed := mm.getCaptureByID(42, false)
assert.NotNil(t, compressed)
// Compressed data should not be valid JSON (it's zstd-compressed)
assert.False(t, gjson.ValidBytes(compressed))
}) })
} }
@@ -1135,9 +1172,13 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
metricID := metrics[0].ID metricID := metrics[0].ID
// Check capture was stored with same ID // Check capture was stored with same ID (decompressed)
capture := mm.getCaptureByID(metricID) captureData := mm.getCaptureByID(metricID, true)
assert.NotNil(t, capture) assert.NotNil(t, captureData)
var capture ReqRespCapture
err = json.Unmarshal(captureData, &capture)
assert.NoError(t, err)
assert.Equal(t, metricID, capture.ID) assert.Equal(t, metricID, capture.ID)
assert.Equal(t, []byte(requestBody), capture.ReqBody) assert.Equal(t, []byte(requestBody), capture.ReqBody)
assert.Equal(t, []byte(responseBody), capture.RespBody) assert.Equal(t, []byte(responseBody), capture.RespBody)
@@ -1173,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
assert.Equal(t, 1, len(metrics)) assert.Equal(t, 1, len(metrics))
// But no capture // But no capture
capture := mm.getCaptureByID(metrics[0].ID) capture := mm.getCaptureByID(metrics[0].ID, false)
assert.Nil(t, capture) assert.Nil(t, capture)
}) })
} }
+62 -4
View File
@@ -77,6 +77,9 @@ type Process struct {
// used for testing to override the default value // used for testing to override the default value
gracefulStopTimeout time.Duration gracefulStopTimeout time.Duration
// used for testing to bypass subprocess and reverse proxy
testHandler http.Handler
// track the number of failed starts // track the number of failed starts
failedStartCount int failedStartCount int
} }
@@ -236,6 +239,49 @@ func (p *Process) forceState(newState ProcessState) {
// at any time. // at any time.
func (p *Process) start() error { func (p *Process) start() error {
// test-only fast path: skip subprocess, health check, and TTL goroutine
if p.testHandler != nil {
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
if err == ErrExpectedStateMismatch {
if curState == StateStarting {
p.waitStarting.Wait()
curState = p.CurrentState()
if curState == StateReady {
return nil
}
return fmt.Errorf("process was already starting but wound up in state %v", curState)
}
return fmt.Errorf("process was in state %v when start() was called", curState)
}
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
}
defer p.waitStarting.Done()
// Mimic the real stop path: cancelUpstream transitions
// StateStopping -> StateStopped and closes cmdWaitChan,
// matching what waitForCmd does for real subprocesses.
ch := make(chan struct{})
p.cmdMutex.Lock()
p.cancelUpstream = func() {
if curState := p.CurrentState(); curState == StateStopping {
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
p.forceState(StateStopped)
}
} else {
p.forceState(StateStopped)
}
close(ch)
}
p.cmdWaitChan = ch
p.cmdMutex.Unlock()
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
}
p.failedStartCount = 0
return nil
}
if p.config.Proxy == "" { if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing") return fmt.Errorf("can not start(), upstream proxy missing")
} }
@@ -386,7 +432,10 @@ func (p *Process) start() error {
// Stop will wait for inflight requests to complete before stopping the process. // Stop will wait for inflight requests to complete before stopping the process.
func (p *Process) Stop() { func (p *Process) Stop() {
// guard to prevent multiple goroutines from stopping
if !isValidTransition(p.CurrentState(), StateStopping) { if !isValidTransition(p.CurrentState(), StateStopping) {
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return return
} }
@@ -399,13 +448,17 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() { func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
// guard to prevent multiple goroutines from stopping the process
enterState := p.CurrentState()
if !isValidTransition(enterState, StateStopping) {
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
return return
} }
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
if curState, err := p.swapState(StateReady, StateStopping); err != nil { if curState, err := p.swapState(enterState, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
return return
} }
@@ -577,6 +630,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if !srw.waitForCompletion(completionTimeout) { if !srw.waitForCompletion(completionTimeout) {
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout) p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
} }
}
if p.testHandler != nil {
p.testHandler.ServeHTTP(w, r)
} else if srw != nil {
p.reverseProxy.ServeHTTP(srw, r) p.reverseProxy.ServeHTTP(srw, r)
} else { } else {
p.reverseProxy.ServeHTTP(w, r) p.reverseProxy.ServeHTTP(w, r)
+36
View File
@@ -24,6 +24,22 @@ type ProcessGroup struct {
// map of current processes // map of current processes
processes map[string]*Process processes map[string]*Process
lastUsedProcess string lastUsedProcess string
// inflight tracks fast-path requests (requests for the already-selected
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
// and Done() on completion; a concurrent swap request calls inflight.Wait()
// under pg.Lock before stopping the current process. Without this tracking,
// a fast-path request that has released pg.Lock but has not yet called
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
// killed mid-request.
inflight sync.WaitGroup
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
// the fast path after pg.Lock is released but before the request is
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
// request at the exact race window to deterministically reproduce the
// fast-path vs swap race.
testDelayFastPath func()
} }
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
@@ -64,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Lock() pg.Lock()
if pg.lastUsedProcess != modelID { if pg.lastUsedProcess != modelID {
// Wait for in-flight fast-path requests to drain before stopping
// the previous process. Without this, a fast-path request that has
// released pg.Lock but has not yet incremented
// Process.inFlightRequests races with Stop() and can be killed
// mid-request.
pg.inflight.Wait()
// is there something already running? // is there something already running?
if pg.lastUsedProcess != "" { if pg.lastUsedProcess != "" {
pg.processes[pg.lastUsedProcess].Stop() pg.processes[pg.lastUsedProcess].Stop()
@@ -78,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
pg.Unlock() pg.Unlock()
return nil return nil
} }
// Fast path: register this request in inflight before releasing
// pg.Lock so a concurrent swap will wait for it to complete.
pg.inflight.Add(1)
defer pg.inflight.Done()
pg.Unlock() pg.Unlock()
if pg.testDelayFastPath != nil {
pg.testDelayFastPath()
}
} }
pg.processes[modelID].ProxyRequest(writer, request) pg.processes[modelID].ProxyRequest(writer, request)
@@ -123,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock() pg.Lock()
defer pg.Unlock() defer pg.Unlock()
if strategy != StopImmediately {
pg.inflight.Wait()
}
if len(pg.processes) == 0 { if len(pg.processes) == 0 {
return return
} }
+226
View File
@@ -4,11 +4,14 @@ import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"runtime"
"sync" "sync"
"testing" "testing"
"time"
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
@@ -95,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
wg.Wait() wg.Wait()
} }
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
// request cannot stop the current process while a fast-path request (for the
// already-selected model) is in flight. Without ProcessGroup-level inflight
// tracking, a fast-path request that has released pg.Lock but has not yet
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
// process is killed mid-request.
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
// Bypass real subprocesses so the test is fast and deterministic.
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: run a request through model1 via the swap path so that
// lastUsedProcess == "model1" and subsequent model1 requests take the
// fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
// Fast-path hook: signal arrival at the race window, then wait for
// release. This parks R2 deterministically at the point where pg.Lock
// has been released but Process.inFlightRequests has not yet been
// incremented — the exact window the race exploits.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
// R2: fast-path request for model1. Will pause at the test hook.
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
// Deterministically wait for R2 to reach the race window.
<-r2Reached
// R3: swap request for model2. Must wait for R2 to finish before touching
// model1, otherwise model1 gets killed mid-request.
r3Done := make(chan struct{})
w3 := httptest.NewRecorder()
go func() {
defer close(r3Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
}()
// Spin until R3 has acquired pg.Lock and entered the swap critical
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
// still holding the lock, so TryLock keeps failing.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
// nothing changes, so we wait the full window. In the buggy code, R3
// will Stop() model1 and start serving via model2 within microseconds —
// we exit early once the mutation is observable.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady ||
pg.processes["model2"].CurrentState() != StateStopped {
break
}
done := false
select {
case <-r3Done:
done = true
default:
}
if done {
break
}
runtime.Gosched()
}
// Invariant: R3 must be blocked while R2 is still in flight.
select {
case <-r3Done:
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
"model2 must not be started until R2 finishes and model1 is swapped out")
// Release R2 and let both requests finish.
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
assert.Equal(t, http.StatusOK, w3.Code)
assert.Contains(t, w3.Body.String(), "model2")
}
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
// process while a fast-path ProxyRequest is in the [pg.Unlock,
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
// StopProcesses, the external caller bypasses the inflight guard and kills the
// process mid-request.
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
cfg := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2"},
},
},
})
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
defer pg.StopProcesses(StopImmediately)
pg.processes["model1"].testHandler = newTestHandler("model1")
pg.processes["model2"].testHandler = newTestHandler("model2")
// Prime: model1 is active so subsequent model1 requests take the fast path.
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
primeW := httptest.NewRecorder()
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
require.Equal(t, http.StatusOK, primeW.Code)
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
// Park a fast-path request at the race window.
r2Reached := make(chan struct{})
r2Release := make(chan struct{})
pg.testDelayFastPath = func() {
close(r2Reached)
<-r2Release
}
r2Done := make(chan struct{})
w2 := httptest.NewRecorder()
go func() {
defer close(r2Done)
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
}()
<-r2Reached
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
// the group while a fast-path request is in flight.
r3Done := make(chan struct{})
go func() {
defer close(r3Done)
pg.StopProcesses(StopWaitForInflightRequest)
}()
// Spin until StopProcesses has acquired pg.Lock.
for pg.TryLock() {
pg.Unlock()
runtime.Gosched()
}
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
// and model1 stays Ready. In the buggy code it proceeds immediately and
// kills model1.
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if pg.processes["model1"].CurrentState() != StateReady {
break
}
select {
case <-r3Done:
goto done
default:
}
runtime.Gosched()
}
done:
select {
case <-r3Done:
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
default:
}
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
"model1 must stay Ready while a fast-path request is in flight")
close(r2Release)
<-r2Done
<-r3Done
assert.Equal(t, http.StatusOK, w2.Code)
assert.Contains(t, w2.Body.String(), "model1")
}
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
defer pg.StopProcesses(StopWaitForInflightRequest) defer pg.StopProcesses(StopWaitForInflightRequest)
+18 -3
View File
@@ -290,11 +290,26 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
return return
} }
capture := pm.metricsMonitor.getCaptureByID(id) data, exists := pm.metricsMonitor.getCompressedBytes(id)
if capture == nil { if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"}) c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
return return
} }
c.JSON(http.StatusOK, capture) c.Header("Vary", "Accept-Encoding")
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
if hasZstd {
c.Header("Content-Encoding", "zstd")
c.Data(http.StatusOK, "application/json", data)
} else {
decompressed, err := decompressCapture(data)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
return
}
c.Data(http.StatusOK, "application/json", decompressed)
}
} }
+322 -339
View File
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,72 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import { persistentStore } from "../stores/persistent";
import { calculateHistogramData } from "../lib/histogram";
import TokenHistogram from "./TokenHistogram.svelte";
const nf = new Intl.NumberFormat();
const histogramCollapsed = persistentStore<boolean>("activity-histogram-collapsed", false);
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
const tokensPerSecond = $metrics
.filter((m) => m.tokens_per_second > 0)
.map((m) => m.tokens_per_second);
const histogramData = tokensPerSecond.length > 0
? calculateHistogramData(tokensPerSecond, { minBins: 20, maxBins: 80, binScaling: 3 })
: null;
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
histogramData,
};
});
</script>
<div class="card">
<button
class="flex items-center gap-1 px-4 pt-3 text-xs font-medium text-gray-500 dark:text-gray-400 hover:text-gray-700 dark:hover:text-gray-200 transition-colors"
onclick={() => $histogramCollapsed = !$histogramCollapsed}
>
<svg
class="w-3 h-3 transition-transform"
style="transform: rotate({$histogramCollapsed ? -90 : 0}deg)"
viewBox="0 0 16 16"
fill="currentColor"
>
<path d="M4.5 6l3.5 4 3.5-4H4.5z" />
</svg>
Tokens/sec Distribution
</button>
{#if !$histogramCollapsed}
{#if stats.histogramData}
<TokenHistogram data={stats.histogramData} />
{:else}
<div class="px-4 py-6 text-center text-sm text-gray-500 dark:text-gray-400">
No token speed data yet
</div>
{/if}
{/if}
<div class="grid grid-cols-3 gap-x-6 gap-y-1 px-4 pb-3 text-sm">
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Requests</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Processed</div>
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Generated</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalRequests)}</span> completed,
<span class="font-semibold">{nf.format(stats.inFlightRequests)}</span> waiting
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalInputTokens)}</span> tokens
</div>
<div class="text-sm text-gray-700 dark:text-gray-300">
<span class="font-semibold">{nf.format(stats.totalOutputTokens)}</span> tokens
</div>
</div>
</div>
-167
View File
@@ -1,167 +0,0 @@
<script lang="ts">
import { inFlightRequests, metrics } from "../stores/api";
import TokenHistogram from "./TokenHistogram.svelte";
interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
let stats = $derived.by(() => {
const totalRequests = $metrics.length;
if (totalRequests === 0) {
return {
totalRequests: 0,
totalInputTokens: 0,
totalOutputTokens: 0,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
// Calculate token statistics using output_tokens and duration_ms
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
if (validMetrics.length === 0) {
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: { p99: "0", p95: "0", p50: "0" },
histogramData: null,
};
}
// Calculate tokens/second for each valid metric
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
// Sort for percentile calculation
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
// Create histogram data
const min = Math.min(...tokensPerSecond);
const max = Math.max(...tokensPerSecond);
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
const binSize = (max - min) / binCount;
const bins = Array(binCount).fill(0);
tokensPerSecond.forEach((value) => {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
});
const histogramData: HistogramData = {
bins,
min,
max,
binSize,
p99,
p95,
p50,
};
return {
totalRequests,
totalInputTokens,
totalOutputTokens,
inFlightRequests: $inFlightRequests,
tokenStats: {
p99: p99.toFixed(2),
p95: p95.toFixed(2),
p50: p50.toFixed(2),
},
histogramData,
};
});
const nf = new Intl.NumberFormat();
</script>
<div class="card">
<div class="rounded-lg overflow-hidden border border-card-border-inner">
<table class="min-w-full divide-y divide-card-border-inner">
<thead class="bg-secondary">
<tr>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Processed
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Generated
</th>
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
Token Stats (tokens/sec)
</th>
</tr>
</thead>
<tbody class="bg-surface divide-y divide-card-border-inner">
<tr class="hover:bg-secondary">
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
<div class="flex flex-col gap-1">
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
</div>
</td>
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
<div class="space-y-3">
<div class="grid grid-cols-3 gap-2 items-center">
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p50}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p95}
</div>
</div>
<div class="text-center">
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
{stats.tokenStats.p99}
</div>
</div>
</div>
{#if stats.histogramData}
<TokenHistogram data={stats.histogramData} />
{/if}
</div>
</td>
</tr>
</tbody>
</table>
</div>
</div>
+5 -21
View File
@@ -1,23 +1,11 @@
<script lang="ts"> <script lang="ts">
interface HistogramData { import type { HistogramData } from "../lib/types";
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
interface Props { let { data }: { data: HistogramData } = $props();
data: HistogramData;
}
let { data }: Props = $props(); const height = 55;
const padding = { top: 5, right: 45, bottom: 15, left: 45 };
const height = 120; const viewBoxWidth = 1200;
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
const viewBoxWidth = 600;
const chartWidth = viewBoxWidth - padding.left - padding.right; const chartWidth = viewBoxWidth - padding.left - padding.right;
const chartHeight = height - padding.top - padding.bottom; const chartHeight = height - padding.top - padding.bottom;
@@ -121,9 +109,5 @@
{data.max.toFixed(1)} {data.max.toFixed(1)}
</text> </text>
<!-- X-axis label -->
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
Tokens/Second Distribution
</text>
</svg> </svg>
</div> </div>
+163
View File
@@ -0,0 +1,163 @@
import { describe, it, expect } from "vitest";
import { calculateHistogramData } from "./histogram";
describe("calculateHistogramData", () => {
describe("edge cases", () => {
it("returns null for empty input", () => {
expect(calculateHistogramData([])).toBeNull();
});
it("handles single value", () => {
const result = calculateHistogramData([42]);
expect(result).not.toBeNull();
expect(result!.bins).toEqual([1]);
expect(result!.min).toBe(42);
expect(result!.max).toBe(42);
expect(result!.binSize).toBe(0);
expect(result!.p50).toBe(42);
expect(result!.p95).toBe(42);
expect(result!.p99).toBe(42);
});
it("handles all identical values", () => {
const result = calculateHistogramData([10, 10, 10, 10, 10]);
expect(result).not.toBeNull();
expect(result!.bins).toEqual([5]);
expect(result!.min).toBe(10);
expect(result!.max).toBe(10);
expect(result!.binSize).toBe(0);
});
it("handles two distinct values", () => {
const result = calculateHistogramData([10, 20]);
expect(result).not.toBeNull();
expect(result!.min).toBe(10);
expect(result!.max).toBe(20);
expect(result!.p50).toBe(15);
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(2);
});
});
describe("bin distribution", () => {
it("bins sum to total number of values", () => {
const values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(values.length);
});
it("distributes uniform values across bins", () => {
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.bins.length).toBe(20);
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(100);
});
it("places values in correct bins", () => {
const values = [1, 1, 1, 5, 5, 9, 9, 9];
const result = calculateHistogramData(values, { minBins: 3, maxBins: 3, binScaling: 1 });
expect(result).not.toBeNull();
expect(result!.bins.length).toBe(3);
expect(result!.bins.reduce((s, b) => s + b, 0)).toBe(8);
});
it("handles skewed distribution", () => {
const values = [1, 1, 1, 1, 1, 100];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
const binSum = result!.bins.reduce((s, b) => s + b, 0);
expect(binSum).toBe(6);
});
});
describe("percentiles", () => {
it("calculates correct p50 for even-length array", () => {
const values = [1, 2, 3, 4];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBe(2.5);
});
it("calculates correct p50 for odd-length array", () => {
const values = [1, 2, 3, 4, 5];
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBe(3);
});
it("calculates p99 with interpolation", () => {
const values = Array.from({ length: 100 }, (_, i) => i + 1);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p99).toBeCloseTo(99.01);
});
it("calculates p95 with interpolation", () => {
const values = Array.from({ length: 100 }, (_, i) => i + 1);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p95).toBeCloseTo(95.05);
});
it("percentiles are monotonically increasing", () => {
const values = Array.from({ length: 200 }, () => Math.random() * 100);
const result = calculateHistogramData(values);
expect(result).not.toBeNull();
expect(result!.p50).toBeLessThanOrEqual(result!.p95);
expect(result!.p95).toBeLessThanOrEqual(result!.p99);
});
});
describe("bin count adaptation", () => {
it("uses minimum bins for small datasets", () => {
const values = Array.from({ length: 20 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result!.bins.length).toBe(10);
});
it("scales bins with dataset size", () => {
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result!.bins.length).toBe(20);
});
it("caps bins at maximum", () => {
const values = Array.from({ length: 200 }, (_, i) => i);
const result = calculateHistogramData(values);
expect(result!.bins.length).toBe(30);
});
it("respects custom options", () => {
const values = Array.from({ length: 100 }, (_, i) => i);
const result = calculateHistogramData(values, { minBins: 5, maxBins: 10, binScaling: 2 });
expect(result!.bins.length).toBe(10);
});
});
describe("min and max", () => {
it("correctly identifies min and max", () => {
const values = [5, 3, 8, 1, 9, 2];
const result = calculateHistogramData(values);
expect(result!.min).toBe(1);
expect(result!.max).toBe(9);
});
it("handles negative values", () => {
const values = [-10, -5, 0, 5, 10];
const result = calculateHistogramData(values);
expect(result!.min).toBe(-10);
expect(result!.max).toBe(10);
});
it("handles floating point values", () => {
const values = [1.5, 2.7, 3.14, 0.5, 4.99];
const result = calculateHistogramData(values);
expect(result!.min).toBe(0.5);
expect(result!.max).toBe(4.99);
});
});
});
+72
View File
@@ -0,0 +1,72 @@
import type { HistogramData } from "./types";
export interface HistogramOptions {
minBins?: number;
maxBins?: number;
binScaling?: number;
}
const DEFAULT_OPTIONS: HistogramOptions = {
minBins: 10,
maxBins: 30,
binScaling: 5,
};
function percentile(sorted: number[], p: number): number {
if (sorted.length === 0) return 0;
if (sorted.length === 1) return sorted[0];
const rank = (p / 100) * (sorted.length - 1);
const lower = Math.floor(rank);
const upper = Math.ceil(rank);
const fraction = rank - lower;
return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
}
export function calculateHistogramData(
values: number[],
options: HistogramOptions = DEFAULT_OPTIONS,
): HistogramData | null {
if (values.length === 0) return null;
const sorted = [...values].sort((a, b) => a - b);
const min = sorted[0];
const max = sorted[sorted.length - 1];
const p50 = percentile(sorted, 50);
const p95 = percentile(sorted, 95);
const p99 = percentile(sorted, 99);
if (min === max) {
return {
bins: [values.length],
min,
max,
binSize: 0,
p50,
p95,
p99,
};
}
const { minBins = 10, maxBins = 30, binScaling = 5 } = options;
const binCount = Math.min(maxBins, Math.max(minBins, Math.floor(values.length / binScaling)));
const binSize = (max - min) / binCount;
const bins = new Array(binCount).fill(0);
for (const value of values) {
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
bins[binIndex]++;
}
return {
bins,
min,
max,
binSize,
p50,
p95,
p99,
};
}
+10
View File
@@ -48,6 +48,16 @@ export interface APIEventEnvelope {
data: string; data: string;
} }
export interface HistogramData {
bins: number[];
min: number;
max: number;
binSize: number;
p99: number;
p95: number;
p50: number;
}
export interface VersionInfo { export interface VersionInfo {
build_date: string; build_date: string;
commit: string; commit: string;
+35 -29
View File
@@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import { metrics, getCapture } from "../stores/api"; import { metrics, getCapture } from "../stores/api";
import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte"; import Tooltip from "../components/Tooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte";
import type { ReqRespCapture } from "../lib/types"; import type { ReqRespCapture } from "../lib/types";
@@ -63,33 +64,38 @@
<div class="p-2"> <div class="p-2">
<h1 class="text-2xl font-bold">Activity</h1> <h1 class="text-2xl font-bold">Activity</h1>
<div class="mt-4 mb-4">
<ActivityStats />
</div>
{#if $metrics.length === 0} <div class="card overflow-auto">
<div class="text-center py-8"> <table class="min-w-full divide-y">
<p class="text-gray-600">No metrics data available</p> <thead class="border-gray-200 dark:border-white/10">
</div> <tr class="text-left text-xs uppercase tracking-wider">
{:else} <th class="px-6 py-3">ID</th>
<div class="card overflow-auto"> <th class="px-6 py-3">Time</th>
<table class="min-w-full divide-y"> <th class="px-6 py-3">Model</th>
<thead class="border-gray-200 dark:border-white/10"> <th class="px-6 py-3">
<tr class="text-left text-xs uppercase tracking-wider"> Cached <Tooltip content="prompt tokens from cache" />
<th class="px-6 py-3">ID</th> </th>
<th class="px-6 py-3">Time</th> <th class="px-6 py-3">
<th class="px-6 py-3">Model</th> Prompt <Tooltip content="new prompt tokens processed" />
<th class="px-6 py-3"> </th>
Cached <Tooltip content="prompt tokens from cache" /> <th class="px-6 py-3">Generated</th>
</th> <th class="px-6 py-3">Prompt Processing</th>
<th class="px-6 py-3"> <th class="px-6 py-3">Generation Speed</th>
Prompt <Tooltip content="new prompt tokens processed" /> <th class="px-6 py-3">Duration</th>
</th> <th class="px-6 py-3">Capture</th>
<th class="px-6 py-3">Generated</th> </tr>
<th class="px-6 py-3">Prompt Processing</th> </thead>
<th class="px-6 py-3">Generation Speed</th> <tbody class="divide-y">
<th class="px-6 py-3">Duration</th> {#if sortedMetrics.length === 0}
<th class="px-6 py-3">Capture</th> <tr>
<td colspan="10" class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded
</td>
</tr> </tr>
</thead> {:else}
<tbody class="divide-y">
{#each sortedMetrics as metric (metric.id)} {#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10"> <tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
<td class="px-4 py-4">{metric.id + 1}</td> <td class="px-4 py-4">{metric.id + 1}</td>
@@ -116,10 +122,10 @@
</td> </td>
</tr> </tr>
{/each} {/each}
</tbody> {/if}
</table> </tbody>
</div> </table>
{/if} </div>
</div> </div>
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} /> <CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
+1 -9
View File
@@ -2,7 +2,6 @@
import { isNarrow } from "../stores/theme"; import { isNarrow } from "../stores/theme";
import { upstreamLogs } from "../stores/api"; import { upstreamLogs } from "../stores/api";
import ModelsPanel from "../components/ModelsPanel.svelte"; import ModelsPanel from "../components/ModelsPanel.svelte";
import StatsPanel from "../components/StatsPanel.svelte";
import LogPanel from "../components/LogPanel.svelte"; import LogPanel from "../components/LogPanel.svelte";
import ResizablePanels from "../components/ResizablePanels.svelte"; import ResizablePanels from "../components/ResizablePanels.svelte";
@@ -14,13 +13,6 @@
<ModelsPanel /> <ModelsPanel />
{/snippet} {/snippet}
{#snippet rightPanel()} {#snippet rightPanel()}
<div class="flex flex-col h-full space-y-4"> <LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
{#if direction === "horizontal"}
<StatsPanel />
{/if}
<div class="flex-1 min-h-0">
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
</div>
</div>
{/snippet} {/snippet}
</ResizablePanels> </ResizablePanels>