Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e8d4384cd2 | |||
| ce28485be2 | |||
| 3cd7837b1f | |||
| 0b31ccacc1 | |||
| 5938dbee8f | |||
| 66639e83f7 | |||
| 625b296720 | |||
| 231e62291c | |||
| 57ac666598 | |||
| 69728301f5 | |||
| c176fa70f1 | |||
| 5e3c646829 | |||
| c3f0d43e6e | |||
| f6cf9f5844 | |||
| 121fd93ad8 |
@@ -19,9 +19,6 @@ jobs:
|
|||||||
|
|
||||||
run-tests:
|
run-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
working-directory: ui-svelte
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
@@ -32,11 +29,5 @@ jobs:
|
|||||||
cache: 'npm'
|
cache: 'npm'
|
||||||
cache-dependency-path: ui-svelte/package-lock.json
|
cache-dependency-path: ui-svelte/package-lock.json
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Run UI tests
|
||||||
run: npm ci
|
run: make test-ui
|
||||||
|
|
||||||
- name: Type check
|
|
||||||
run: npm run check
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: npm test
|
|
||||||
|
|||||||
@@ -36,6 +36,11 @@ on:
|
|||||||
type: boolean
|
type: boolean
|
||||||
required: false
|
required: false
|
||||||
default: true
|
default: true
|
||||||
|
push_to_ghcr:
|
||||||
|
description: "Push images to ghcr.io"
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
@@ -116,7 +121,7 @@ jobs:
|
|||||||
docker/unified/build-image.sh --${{ matrix.backend }}
|
docker/unified/build-image.sh --${{ matrix.backend }}
|
||||||
|
|
||||||
- name: Push to GitHub Container Registry
|
- name: Push to GitHub Container Registry
|
||||||
if: ${{ !env.ACT }}
|
if: ${{ !env.ACT && (github.event_name == 'schedule' || inputs.push_to_ghcr == true) }}
|
||||||
run: |
|
run: |
|
||||||
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
|
BASE_TAG="ghcr.io/mostlygeek/llama-swap:unified-${{ matrix.backend }}"
|
||||||
DATE_TAG=$(date -u +%Y-%m-%d)
|
DATE_TAG=$(date -u +%Y-%m-%d)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
- Run `gofmt -l .` before committing to verify formatting. Fix any reported files with `gofmt -w <file>`.
|
||||||
- Use `make test-dev` after running new tests for a quick overall test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory.
|
- Use `make test-dev` after running new tests for a quick overall test run. This runs `go test` and `staticcheck`. Fix any static checking errors. Use this only when changes are made to any code under the `proxy/` directory.
|
||||||
- Use `make test-all` before completing work. This includes long-running concurrency tests.
|
- Use `make test-all` before completing work. This includes long-running concurrency tests.
|
||||||
|
- Use `make test-ui` after making changes to the UI in `ui-svelte/`.
|
||||||
|
|
||||||
### Commit message example format:
|
### Commit message example format:
|
||||||
|
|
||||||
|
|||||||
@@ -48,10 +48,15 @@ mac: ui
|
|||||||
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64
|
||||||
|
|
||||||
# Build Linux binary
|
# Build Linux binary
|
||||||
linux: ui
|
linux: linux-arm64 linux-amd64
|
||||||
@echo "Building Linux binary..."
|
|
||||||
|
linux-amd64: ui
|
||||||
|
@echo "Building Linux AMD64 binary..."
|
||||||
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64
|
||||||
#GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
|
||||||
|
linux-arm64: ui
|
||||||
|
@echo "Building Linux ARM64 binary..."
|
||||||
|
GOOS=linux GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-arm64
|
||||||
|
|
||||||
# Build Windows binary
|
# Build Windows binary
|
||||||
windows: ui
|
windows: ui
|
||||||
@@ -92,5 +97,9 @@ wol-proxy: $(BUILD_DIR)
|
|||||||
@echo "Building wol-proxy"
|
@echo "Building wol-proxy"
|
||||||
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
go build -o $(BUILD_DIR)/wol-proxy-$(GOOS)-$(GOARCH)-$(shell date +%Y-%m-%d) cmd/wol-proxy/wol-proxy.go
|
||||||
|
|
||||||
|
test-ui:
|
||||||
|
cd ui-svelte && npm ci && npm run check && npm test
|
||||||
|
|
||||||
# Phony targets
|
# Phony targets
|
||||||
.PHONY: all clean ui mac linux windows simple-responder simple-responder-windows test test-all test-dev wol-proxy
|
.PHONY: all clean ui mac windows simple-responder simple-responder-windows test test-all test-dev test-ui wol-proxy
|
||||||
|
.PHONY: linux linux-arm64 linux-amd64
|
||||||
|
|||||||
@@ -0,0 +1,183 @@
|
|||||||
|
# Improve Testability (#655)
|
||||||
|
|
||||||
|
## Current Pain Points
|
||||||
|
|
||||||
|
1. **Tests bypass config loading** - ~80% of tests build `config.Config` structs directly, skipping YAML parsing, env var substitution, macro expansion, and `${PORT}` assignment. Config bugs in those paths go untested.
|
||||||
|
|
||||||
|
2. **simple-responder is everywhere** - Every proxy/routing test launches a real subprocess, waits for health checks (~healthCheckTimeout: 15), and manages process lifecycle just to test HTTP routing. Most of that overhead is wasted.
|
||||||
|
|
||||||
|
3. **Port counter is fragile** - A global `nextTestPort` counter starting at 12000 with a mutex. Parallel tests or leftover processes can collide.
|
||||||
|
|
||||||
|
## Stages
|
||||||
|
|
||||||
|
### Stage 1: YAML-based test config helper
|
||||||
|
|
||||||
|
**Goal:** Tests go through the real `LoadConfigFromReader` path instead of hand-building structs.
|
||||||
|
|
||||||
|
**Effort:** Low | **Impact:** Config bugs caught earlier | **Risk:** None
|
||||||
|
|
||||||
|
Create a test helper in `proxy/helpers_test.go`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// testConfigFromYAML substitutes simple-responder paths and loads through
|
||||||
|
// the real config pipeline (env vars, macros, port assignment, etc.)
|
||||||
|
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||||
|
t.Helper()
|
||||||
|
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||||
|
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests would then look like:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
|
config := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||||
|
model2:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model2
|
||||||
|
`)
|
||||||
|
proxy := New(config)
|
||||||
|
// ... same assertions
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why this stage first:** Zero production code changes. Pure test-side refactoring. Can be done incrementally - migrate tests one at a time. Each migrated test now validates the full config pipeline.
|
||||||
|
|
||||||
|
**Scope:** ~20-30 tests in `proxymanager_test.go`, `processgroup_test.go`, `peerproxy_test.go`.
|
||||||
|
|
||||||
|
### Stage 2: Injected test handler (eliminate simple-responder for routing tests)
|
||||||
|
|
||||||
|
**Goal:** Replace simple-responder subprocess launches with an injected `http.Handler` for tests that don't specifically test process lifecycle.
|
||||||
|
|
||||||
|
**Effort:** Medium | **Impact:** 10-100x faster routing tests | **Risk:** Low (additive, no existing code broken)
|
||||||
|
|
||||||
|
Add a `testHandler http.Handler` field to `Process`. When set, `ProxyRequest` delegates directly to this handler instead of going through the reverse proxy. No subprocess, no health checks, no TCP roundtrip.
|
||||||
|
|
||||||
|
**2a. Add testHandler to Process:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// In Process struct (process.go):
|
||||||
|
testHandler http.Handler // set only in tests; bypasses subprocess and reverse proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
In `Process.Start()`, skip subprocess + health check when handler is set:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func (p *Process) start() error {
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.setState(StateReady)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// existing subprocess logic...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
In `Process.ProxyRequest()`, delegate directly to the handler:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Before the reverseProxy.ServeHTTP call:
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.testHandler.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**2b. Test helper to create the handler:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// newTestHandler returns an http.Handler that mimics llama.cpp's API
|
||||||
|
// (same endpoints as simple-responder).
|
||||||
|
func newTestHandler(respond string) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||||
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { ... })
|
||||||
|
// ... other endpoints
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests for routing/auth/CORS/streaming then become:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestProxyManager_AuthRequired(t *testing.T) {
|
||||||
|
handler := newTestHandler("model1")
|
||||||
|
|
||||||
|
config := testConfigFromYAML(t, `
|
||||||
|
healthCheckTimeout: 15
|
||||||
|
logLevel: error
|
||||||
|
requiredAPIKeys: [test-key]
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: {{RESPONDER}} --port ${PORT} -silent -respond model1
|
||||||
|
`)
|
||||||
|
pm := NewProxyManager(config)
|
||||||
|
// inject handler — skips subprocess, health check, port allocation
|
||||||
|
pm.processGroups["model1"].process.testHandler = handler
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why this matters:** The handler is called directly in-process. No subprocess spawn, no health check timeout, no port allocation, no TCP roundtrip, no reverse proxy overhead. Routing tests go from ~100ms each (process startup + health check) to ~1ms. Unlike an `httptest.Server` approach, there are zero network hops.
|
||||||
|
|
||||||
|
**Why not blank-cmd + proxy URL:** A blank `cmd` with a `proxy` field pointing at `httptest.Server` still requires a real TCP roundtrip through the reverse proxy and introduces "external process" semantics to the config schema. Injecting the handler directly keeps it purely a test concern with no config changes.
|
||||||
|
|
||||||
|
**Scope:** Most tests in `proxymanager_test.go` (auth, CORS, model listing, streaming, peer proxy), `peerproxy_test.go`, `metrics_monitor_test.go`.
|
||||||
|
|
||||||
|
### Stage 3: Migrate tests incrementally
|
||||||
|
|
||||||
|
**Goal:** Convert existing tests to use the Stage 1 + Stage 2 helpers.
|
||||||
|
|
||||||
|
**Effort:** Medium | **Impact:** Cleaner, more reliable tests | **Risk:** None
|
||||||
|
|
||||||
|
Priority order:
|
||||||
|
1. `proxymanager_test.go` routing tests (highest count, most repetition)
|
||||||
|
2. `peerproxy_test.go` (straightforward, all HTTP routing)
|
||||||
|
3. `metrics_monitor_test.go` (capture logic doesn't need real processes)
|
||||||
|
4. `processgroup_test.go` swap tests (keep simple-responder for actual swap lifecycle tests)
|
||||||
|
|
||||||
|
Tests that **must keep simple-responder:**
|
||||||
|
- Process lifecycle: start/stop, SIGKILL, SIGTERM, TTL expiry, health check failures, failed start counting
|
||||||
|
- ProcessGroup swap concurrency (the port-collision test in `TestProcessGroup_ProxyRequestSwapIsTrueParallel`)
|
||||||
|
|
||||||
|
**Scope:** ~60-70% of tests can drop simple-responder.
|
||||||
|
|
||||||
|
### Stage 4 (optional): Process interface for ProcessGroup
|
||||||
|
|
||||||
|
**Goal:** Enable pure unit tests of ProcessGroup's swap/exclusive/concurrency logic without any HTTP server at all.
|
||||||
|
|
||||||
|
**Effort:** High | **Impact:** Pure unit tests possible | **Risk:** Medium (refactor core code)
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ProcessController interface {
|
||||||
|
Start() error
|
||||||
|
Stop(StopStrategy)
|
||||||
|
ProxyRequest(http.ResponseWriter, *http.Request) error
|
||||||
|
CurrentState() ProcessState
|
||||||
|
ID() string
|
||||||
|
SetState(ProcessState) // for test setup
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This requires:
|
||||||
|
- Extracting the interface
|
||||||
|
- A `MockProcess` implementation
|
||||||
|
- Refactoring `ProcessGroup` to use the interface instead of `*Process`
|
||||||
|
|
||||||
|
**Recommendation:** Only do this if ProcessGroup grows significantly more complex. Stages 1-3 give 80% of the benefit for 20% of the effort.
|
||||||
|
|
||||||
|
## Effort/Impact Summary
|
||||||
|
|
||||||
|
| Stage | Effort | Impact | Risk |
|
||||||
|
|-------|--------|--------|------|
|
||||||
|
| 1. YAML config helper | Low | Config bugs caught earlier | None |
|
||||||
|
| 2. Injected test handler | Medium | 10-100x faster routing tests | Low |
|
||||||
|
| 3. Migrate tests | Medium | Cleaner, more reliable tests | None |
|
||||||
|
| 4. Process interface | High | Pure unit tests possible | Medium |
|
||||||
|
|
||||||
|
**Recommended approach:** Do stages 1-3 in order. Each stage is independently valuable and can ship on its own. Stage 4 is deferred unless there's a specific need.
|
||||||
@@ -42,6 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
build-essential cmake git python3 python3-pip libssl-dev \
|
build-essential cmake git python3 python3-pip libssl-dev \
|
||||||
curl ca-certificates ccache make wget software-properties-common \
|
curl ca-certificates ccache make wget software-properties-common \
|
||||||
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
libvulkan-dev glslang-tools spirv-tools vulkan-validationlayers glslc \
|
||||||
|
spirv-headers \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
@@ -148,7 +149,7 @@ ARG IK_LLAMA_COMMIT_HASH=unknown
|
|||||||
ARG RUN_UID=0
|
ARG RUN_UID=0
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3-numpy python3-sentencepiece \
|
python3-numpy python3-sentencepiece python3-pip \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Create non-root user when RUN_UID != 0
|
# Create non-root user when RUN_UID != 0
|
||||||
@@ -179,6 +180,9 @@ COPY --from=llama-build /install/bin/llama-cli /usr/local/bin/
|
|||||||
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
# Copy ik-llama-server (CUDA only; empty copy for vulkan)
|
||||||
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
COPY --from=ik-llama-build /install/bin/ /usr/local/bin/
|
||||||
|
|
||||||
|
# Install uv
|
||||||
|
RUN pip install uv --break-system-packages
|
||||||
|
|
||||||
# Copy llama-swap binary
|
# Copy llama-swap binary
|
||||||
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
COPY --from=llama-swap-download /install/bin/llama-swap /usr/local/bin/
|
||||||
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
COPY --from=llama-swap-download /install/llama-swap-version /tmp/
|
||||||
|
|||||||
@@ -38,8 +38,16 @@ if [ "$VERSION" = "latest" ]; then
|
|||||||
echo "Latest version: ${VERSION}"
|
echo "Latest version: ${VERSION}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ARCH="amd64" ;;
|
||||||
|
aarch64|arm64) ARCH="arm64" ;;
|
||||||
|
*) echo "FATAL: Unsupported architecture: $ARCH" >&2; exit 1 ;;
|
||||||
|
esac
|
||||||
|
|
||||||
# Download and extract
|
# Download and extract
|
||||||
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_amd64.tar.gz"
|
URL="https://github.com/${REPO}/releases/download/v${VERSION}/llama-swap_${VERSION}_linux_${ARCH}.tar.gz"
|
||||||
echo "=== Downloading llama-swap v${VERSION} ==="
|
echo "=== Downloading llama-swap v${VERSION} ==="
|
||||||
echo "URL: $URL"
|
echo "URL: $URL"
|
||||||
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
curl -fSL -o /tmp/llama-swap.tar.gz "$URL"
|
||||||
@@ -56,4 +64,4 @@ fi
|
|||||||
echo "$VERSION" > /install/llama-swap-version
|
echo "$VERSION" > /install/llama-swap-version
|
||||||
|
|
||||||
echo "=== llama-swap v${VERSION} installed ==="
|
echo "=== llama-swap v${VERSION} installed ==="
|
||||||
ls -la /install/bin/llama-swap
|
ls -la /install/bin/llama-swap
|
||||||
@@ -4,8 +4,8 @@ go 1.26.1
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/billziss-gh/golib v0.2.0
|
github.com/billziss-gh/golib v0.2.0
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
github.com/klauspost/compress v1.18.5
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
|
|||||||
@@ -11,8 +11,6 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
@@ -34,6 +32,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||||
|
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
|
|||||||
+46
-40
@@ -9,14 +9,15 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy"
|
"github.com/mostlygeek/llama-swap/proxy"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/mostlygeek/llama-swap/proxy/configwatcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -79,6 +80,17 @@ func main() {
|
|||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Reload signals (SIGHUP on POSIX, none on Windows — Windows does not
|
||||||
|
// deliver SIGHUP). Always wired up so `kill -HUP` works regardless of
|
||||||
|
// --watch-config.
|
||||||
|
reloadChan := make(chan os.Signal, 1)
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
signal.Notify(reloadChan, syscall.SIGHUP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context that bounds the lifetime of background watcher goroutines.
|
||||||
|
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Create server with initial handler
|
// Create server with initial handler
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: *listenStr,
|
Addr: *listenStr,
|
||||||
@@ -121,52 +133,45 @@ func main() {
|
|||||||
// load the initial proxy manager
|
// load the initial proxy manager
|
||||||
reloadProxyManager()
|
reloadProxyManager()
|
||||||
debouncedReload := debounce(time.Second, reloadProxyManager)
|
debouncedReload := debounce(time.Second, reloadProxyManager)
|
||||||
if *watchConfig {
|
|
||||||
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
|
||||||
if e.ReloadingState == proxy.ReloadingStateStart {
|
|
||||||
debouncedReload()
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
|
|
||||||
fmt.Println("Watching Configuration for changes")
|
// Listen for ConfigFileChangedEvent unconditionally so SIGHUP and the
|
||||||
|
// poll-based watcher both feed the same debounced reload pipeline. The
|
||||||
|
// UI also listens for the matching ReloadingStateEnd emitted from
|
||||||
|
// reloadProxyManager.
|
||||||
|
defer event.On(func(e proxy.ConfigFileChangedEvent) {
|
||||||
|
if e.ReloadingState == proxy.ReloadingStateStart {
|
||||||
|
debouncedReload()
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
// SIGHUP (or platform-equivalent) → reload. Back-to-back signals collapse
|
||||||
|
// to one reload via the debounce window, which is the desired behavior.
|
||||||
|
go func() {
|
||||||
|
for range reloadChan {
|
||||||
|
fmt.Println("Received reload signal, reloading configuration")
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if *watchConfig {
|
||||||
go func() {
|
go func() {
|
||||||
absConfigPath, err := filepath.Abs(*configPath)
|
absConfigPath, err := filepath.Abs(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
fmt.Printf("Error getting absolute path for watching config file: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
watcher, err := fsnotify.NewWatcher()
|
fmt.Println("Watching configuration for changes (poll-based, 2s interval)")
|
||||||
if err != nil {
|
(&configwatcher.Watcher{
|
||||||
fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err)
|
Path: absConfigPath,
|
||||||
return
|
Interval: configwatcher.DefaultInterval,
|
||||||
}
|
OnChange: func() {
|
||||||
|
event.Emit(proxy.ConfigFileChangedEvent{
|
||||||
configDir := filepath.Dir(absConfigPath)
|
ReloadingState: proxy.ReloadingStateStart,
|
||||||
err = watcher.Add(configDir)
|
})
|
||||||
if err != nil {
|
},
|
||||||
fmt.Printf("Error adding config path directory (%s) to watcher: %v. File watching disabled.", configDir, err)
|
}).Run(watcherCtx)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer watcher.Close()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case changeEvent := <-watcher.Events:
|
|
||||||
if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) {
|
|
||||||
event.Emit(proxy.ConfigFileChangedEvent{
|
|
||||||
ReloadingState: proxy.ReloadingStateStart,
|
|
||||||
})
|
|
||||||
} else if changeEvent.Name == filepath.Join(configDir, "..data") && changeEvent.Has(fsnotify.Create) {
|
|
||||||
// the change for k8s configmap
|
|
||||||
event.Emit(proxy.ConfigFileChangedEvent{
|
|
||||||
ReloadingState: proxy.ReloadingStateStart,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
case err := <-watcher.Errors:
|
|
||||||
log.Printf("File watcher error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,6 +179,7 @@ func main() {
|
|||||||
go func() {
|
go func() {
|
||||||
sig := <-sigChan
|
sig := <-sigChan
|
||||||
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
fmt.Printf("Received signal %v, shutting down...\n", sig)
|
||||||
|
watcherCancel()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
// Package configwatcher provides a simple cross-platform file watcher based
|
||||||
|
// on os.Stat polling. It works correctly inside Docker containers where the
|
||||||
|
// config file is bind-mounted as an individual file, and for k8s ConfigMap
|
||||||
|
// projections (which present the file as a symlink to an atomically swapped
|
||||||
|
// target) — both cases where inotify-based watchers are unreliable.
|
||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultInterval = 2 * time.Second
|
||||||
|
|
||||||
|
type Watcher struct {
|
||||||
|
Path string
|
||||||
|
Interval time.Duration
|
||||||
|
OnChange func()
|
||||||
|
}
|
||||||
|
|
||||||
|
type snapshot struct {
|
||||||
|
exists bool
|
||||||
|
modTime time.Time
|
||||||
|
size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||||
|
// OnChange whenever the file's modification time or size changes, or when
|
||||||
|
// the file reappears after being missing. The baseline poll establishes
|
||||||
|
// initial state and does not fire OnChange.
|
||||||
|
func (w *Watcher) Run(ctx context.Context) {
|
||||||
|
interval := w.Interval
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = DefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := stat(w.Path)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cur := stat(w.Path)
|
||||||
|
if changed(prev, cur) && w.OnChange != nil {
|
||||||
|
w.OnChange()
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stat(path string) snapshot {
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
log.Printf("configwatcher: stat %s: %v", path, err)
|
||||||
|
}
|
||||||
|
return snapshot{}
|
||||||
|
}
|
||||||
|
return snapshot{
|
||||||
|
exists: true,
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
size: fi.Size(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func changed(prev, cur snapshot) bool {
|
||||||
|
// Present → missing: stay quiet (likely a transient rename-style write).
|
||||||
|
// Missing → present: fire so we reload as soon as the file comes back.
|
||||||
|
if !cur.exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !prev.exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !prev.modTime.Equal(cur.modTime) || prev.size != cur.size
|
||||||
|
}
|
||||||
@@ -0,0 +1,191 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testInterval = 25 * time.Millisecond
|
||||||
|
|
||||||
|
// startWatcher launches w.Run in a goroutine and returns a function that
|
||||||
|
// cancels the context and waits for Run to return.
|
||||||
|
func startWatcher(t *testing.T, w *Watcher) func() {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.Run(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("watcher did not stop within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForCount blocks until counter reaches want or timeout elapses.
|
||||||
|
func waitForCount(t *testing.T, counter *int64, want int64, timeout time.Duration) bool {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if atomic.LoadInt64(counter) >= want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_NoFireOnBaseline(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 5)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_DetectsModTimeChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
// Force a known baseline mtime.
|
||||||
|
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||||
|
require.NoError(t, os.Chtimes(path, base, base))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// Let the baseline settle.
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Bump mtime well above the baseline so low-resolution filesystems still notice.
|
||||||
|
require.NoError(t, os.Chtimes(path, base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_DetectsSizeChangeWithSameModTime(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
originalMtime := fi.ModTime()
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("aaaaa"), 0o644))
|
||||||
|
// Reset mtime back to the original so size is the only signal.
|
||||||
|
require.NoError(t, os.Chtimes(path, originalMtime, originalMtime))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire on size change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_SymlinkTargetSwap(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
targetA := filepath.Join(dir, "targetA")
|
||||||
|
targetB := filepath.Join(dir, "targetB")
|
||||||
|
link := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(targetA, []byte("AAAA"), 0o644))
|
||||||
|
require.NoError(t, os.WriteFile(targetB, []byte("BBBBBBBB"), 0o644))
|
||||||
|
|
||||||
|
if err := os.Symlink(targetA, link); err != nil {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skipf("symlink creation requires privilege on Windows: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatalf("os.Symlink: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: link,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Atomic symlink swap (k8s ConfigMap pattern): create new symlink at a
|
||||||
|
// temp name, then rename over the existing one.
|
||||||
|
tmpLink := filepath.Join(dir, "config.yaml.tmp")
|
||||||
|
require.NoError(t, os.Symlink(targetB, tmpLink))
|
||||||
|
require.NoError(t, os.Rename(tmpLink, link))
|
||||||
|
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after symlink target swap")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_FileMissingThenReturns(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startWatcher(t, &Watcher{
|
||||||
|
Path: path,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Remove(path))
|
||||||
|
time.Sleep(testInterval * 3)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "removal alone must not fire")
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("b"), 0o644))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when file returns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(path, []byte("a"), 0o644))
|
||||||
|
|
||||||
|
w := &Watcher{Path: path, Interval: testInterval}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { w.Run(ctx); close(done) }()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,22 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,6 +73,16 @@ func getTestPort() int {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testConfigFromYAML substitutes {{RESPONDER}} with the simple-responder path and
|
||||||
|
// loads through the real config pipeline (env vars, macros, port assignment, etc.)
|
||||||
|
func testConfigFromYAML(t *testing.T, yamlTmpl string) config.Config {
|
||||||
|
t.Helper()
|
||||||
|
yamlStr := strings.ReplaceAll(yamlTmpl, "{{RESPONDER}}", filepath.ToSlash(simpleResponderPath))
|
||||||
|
cfg, err := config.LoadConfigFromReader(strings.NewReader(yamlStr))
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
|
||||||
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
|
||||||
}
|
}
|
||||||
@@ -88,3 +105,188 @@ proxy: "http://127.0.0.1:%d"
|
|||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// injectTestHandlers sets a testHandler on every Process in every ProcessGroup
|
||||||
|
// of the given ProxyManager, bypassing subprocess launches. modelResponses maps
|
||||||
|
// model IDs to their respond strings; if a model ID is not in the map, the model
|
||||||
|
// ID itself is used.
|
||||||
|
func injectTestHandlers(pm *ProxyManager, modelResponses map[string]string) {
|
||||||
|
for _, pg := range pm.processGroups {
|
||||||
|
for modelID, process := range pg.processes {
|
||||||
|
respond := modelID
|
||||||
|
if r, ok := modelResponses[modelID]; ok {
|
||||||
|
respond = r
|
||||||
|
}
|
||||||
|
process.testHandler = newTestHandler(respond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestHandler returns an http.Handler that mimics simple-responder's API.
|
||||||
|
// It supports the endpoints that routing tests depend on, without launching
|
||||||
|
// any subprocess or binding any port.
|
||||||
|
func newTestHandler(respond string) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
isStreaming := r.URL.Query().Get("stream") == "true"
|
||||||
|
|
||||||
|
if wait := r.URL.Query().Get("wait"); wait != "" {
|
||||||
|
if d, err := time.ParseDuration(wait); err == nil {
|
||||||
|
time.Sleep(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isStreaming {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
flusher := w.(http.Flusher)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
data, _ := json.Marshal(map[string]any{
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"index": 0, "delta": map[string]any{"content": "asdf"}, "finish_reason": nil},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: %s\n\n", data)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
finalData, _ := json.Marshal(map[string]any{
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": map[string]any{
|
||||||
|
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||||
|
"predicted_ms": 17, "predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: %s\n\n", finalData)
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "event: message\ndata: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"h_content_length": r.Header.Get("Content-Length"),
|
||||||
|
"request_body": string(bodyBytes),
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
"timings": map[string]any{
|
||||||
|
"prompt_n": 25, "prompt_ms": 13, "predicted_n": 10,
|
||||||
|
"predicted_ms": 17, "predicted_per_second": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
if modelName != respond {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, respond)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"message": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/completion", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"responseMessage": respond,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/transcriptions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model := r.FormValue("model")
|
||||||
|
if model == "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": "Missing model parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, _, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fileBytes, _ := io.ReadAll(file)
|
||||||
|
file.Close()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"text": fmt.Sprintf("The length of the file is %d bytes", len(fileBytes)),
|
||||||
|
"model": model,
|
||||||
|
"h_content_type": r.Header.Get("Content-Type"),
|
||||||
|
"h_content_length": r.Header.Get("Content-Length"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/v1/audio/voices", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
model := r.URL.Query().Get("model")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"voices": []string{"voice1"}, "model": model,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprint(w, respond)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/txt2img", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"model": modelName, "images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/img2img", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"model": modelName, "images": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/sdapi/v1/loras", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"loras": []string{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|||||||
+33
-2
@@ -147,6 +147,20 @@ type Matrix struct {
|
|||||||
config config.Config
|
config config.Config
|
||||||
proxyLogger *LogMonitor
|
proxyLogger *LogMonitor
|
||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
|
|
||||||
|
// inflight tracks ProxyRequest calls that have released m.Lock but may
|
||||||
|
// not yet have incremented Process.inFlightRequests. A concurrent
|
||||||
|
// request that needs to evict models waits for inflight to drain under
|
||||||
|
// m.Lock before stopping anything. Without this, a request that
|
||||||
|
// released m.Lock but has not yet reached Process.inFlightRequests.Add(1)
|
||||||
|
// races with Stop()'s Wait() and can be killed mid-request.
|
||||||
|
inflight sync.WaitGroup
|
||||||
|
|
||||||
|
// testDelayFastPath is a test-only hook invoked in the no-eviction path
|
||||||
|
// after m.Lock is released but before the request is dispatched to
|
||||||
|
// Process.ProxyRequest. Tests use it to park a request at the exact
|
||||||
|
// race window to deterministically reproduce the race.
|
||||||
|
testDelayFastPath func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMatrix creates a Matrix from config. It creates a Process for every
|
// NewMatrix creates a Matrix from config. It creates a Process for every
|
||||||
@@ -197,6 +211,13 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req
|
|||||||
|
|
||||||
// Evict models that need to be stopped
|
// Evict models that need to be stopped
|
||||||
if len(result.Evict) > 0 {
|
if len(result.Evict) > 0 {
|
||||||
|
// Wait for any in-flight ProxyRequest calls to register on their
|
||||||
|
// Process before stopping anything. Without this, a request that
|
||||||
|
// released m.Lock but has not yet incremented
|
||||||
|
// Process.inFlightRequests races with Stop() and can be killed
|
||||||
|
// mid-request.
|
||||||
|
m.inflight.Wait()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, evictModel := range result.Evict {
|
for _, evictModel := range result.Evict {
|
||||||
if p, exists := m.processes[evictModel]; exists {
|
if p, exists := m.processes[evictModel]; exists {
|
||||||
@@ -209,8 +230,18 @@ func (m *Matrix) ProxyRequest(modelID string, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register this request in inflight before releasing m.Lock so a
|
||||||
|
// concurrent eviction will wait for it to complete.
|
||||||
|
m.inflight.Add(1)
|
||||||
|
defer m.inflight.Done()
|
||||||
|
isFastPath := len(result.Evict) == 0
|
||||||
m.Unlock()
|
m.Unlock()
|
||||||
|
|
||||||
|
if isFastPath && m.testDelayFastPath != nil {
|
||||||
|
m.testDelayFastPath()
|
||||||
|
}
|
||||||
|
|
||||||
// Proxy the request (Process handles on-demand start)
|
// Proxy the request (Process handles on-demand start)
|
||||||
process.ProxyRequest(w, r)
|
process.ProxyRequest(w, r)
|
||||||
return nil
|
return nil
|
||||||
@@ -266,7 +297,7 @@ func (m *Matrix) Shutdown() {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunningModels returns model names currently in StateReady.
|
// RunningModels returns model names currently in an active (non-stopped) state.
|
||||||
func (m *Matrix) RunningModels() []string {
|
func (m *Matrix) RunningModels() []string {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
@@ -277,7 +308,7 @@ func (m *Matrix) RunningModels() []string {
|
|||||||
func (m *Matrix) runningModels() []string {
|
func (m *Matrix) runningModels() []string {
|
||||||
var running []string
|
var running []string
|
||||||
for id, process := range m.processes {
|
for id, process := range m.processes {
|
||||||
if process.CurrentState() == StateReady {
|
if process.CurrentState() != StateStopped && process.CurrentState() != StateShutdown {
|
||||||
running = append(running, id)
|
running = append(running, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -169,6 +173,124 @@ func TestMatrixSolver_NothingRunning(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
assert.Equal(t, []string{"g", "v"}, result.TargetSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMatrix_ProxyRequestSwapRaceAgainstFastPath verifies that an eviction
|
||||||
|
// cannot stop a process while an in-flight ProxyRequest for that process is
|
||||||
|
// still in the [m.Unlock, Process.inFlightRequests.Add(1)] window. Without
|
||||||
|
// matrix-level inflight tracking, the eviction's Stop() races with the
|
||||||
|
// pending request and kills it mid-start.
|
||||||
|
func TestMatrix_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||||
|
cfg := config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
ExpandedSets: []config.ExpandedSet{
|
||||||
|
{SetName: "s1", Models: []string{"model1"}},
|
||||||
|
{SetName: "s2", Models: []string{"model2"}},
|
||||||
|
},
|
||||||
|
Matrix: &config.MatrixConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewMatrix(cfg, testLogger, testLogger)
|
||||||
|
defer m.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// Bypass real subprocesses so the test is fast and deterministic.
|
||||||
|
m.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
m.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: run a request through model1 so it reaches StateReady and
|
||||||
|
// subsequent requests take the no-eviction path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, m.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, m.processes["model1"].CurrentState())
|
||||||
|
require.Equal(t, StateStopped, m.processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
// Install fast-path hook that signals arrival and waits for release.
|
||||||
|
// This parks R2 at the race window — after m.Lock is released but
|
||||||
|
// before Process.inFlightRequests.Add(1).
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
m.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
// R2: no-eviction request for model1. Will pause at the hook.
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, m.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Deterministically wait for R2 to reach the race window.
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// R3: request for model2 which requires evicting model1. Must wait for
|
||||||
|
// R2 to finish before touching model1.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, m.ProxyRequest("model2", w3, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until R3 has acquired m.Lock and entered the eviction path. In
|
||||||
|
// the fixed code, R3 then blocks on m.inflight.Wait() while still
|
||||||
|
// holding the lock, so TryLock keeps failing.
|
||||||
|
for m.TryLock() {
|
||||||
|
m.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||||
|
// state. In the fixed code R3 is blocked and nothing changes; in the
|
||||||
|
// buggy code R3 will Stop() model1 and start model2 within microseconds.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if m.processes["model1"].CurrentState() != StateReady ||
|
||||||
|
m.processes["model2"].CurrentState() != StateStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
done := false
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
done = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("eviction completed while in-flight request was still pending — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, m.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while an in-flight request is pending")
|
||||||
|
assert.Equal(t, StateStopped, m.processes["model2"].CurrentState(),
|
||||||
|
"model2 must not be started until R2 finishes and model1 is evicted")
|
||||||
|
|
||||||
|
// Release R2 and let both requests finish.
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
assert.Equal(t, http.StatusOK, w3.Code)
|
||||||
|
assert.Contains(t, w3.Body.String(), "model2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestMatrixSolver_FullScenario(t *testing.T) {
|
func TestMatrixSolver_FullScenario(t *testing.T) {
|
||||||
// Simulates the example config:
|
// Simulates the example config:
|
||||||
// standard: [g,v], [q,v], [m,v]
|
// standard: [g,v], [q,v], [m,v]
|
||||||
|
|||||||
+101
-34
@@ -13,10 +13,54 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// zstdEncOptions are the shared zstd encoder options for maximum compression.
|
||||||
|
var zstdEncOptions = []zstd.EOption{
|
||||||
|
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdDecOptions are the shared zstd decoder options.
|
||||||
|
var zstdDecOptions = []zstd.DOption{}
|
||||||
|
|
||||||
|
// zstdEncPool pools zstd.Encoder instances to reduce allocations.
|
||||||
|
var zstdEncPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
enc, _ := zstd.NewWriter(nil, zstdEncOptions...)
|
||||||
|
return enc
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// zstdDecPool pools zstd.Decoder instances to reduce allocations.
|
||||||
|
var zstdDecPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
dec, _ := zstd.NewReader(nil, zstdDecOptions...)
|
||||||
|
return dec
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressCapture marshals a ReqRespCapture to JSON and compresses it with zstd.
|
||||||
|
// Returns compressed bytes and the original JSON byte count for logging.
|
||||||
|
func compressCapture(c *ReqRespCapture) ([]byte, int, error) {
|
||||||
|
jsonBytes, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("marshal capture: %w", err)
|
||||||
|
}
|
||||||
|
enc := zstdEncPool.Get().(*zstd.Encoder)
|
||||||
|
defer zstdEncPool.Put(enc)
|
||||||
|
return enc.EncodeAll(jsonBytes, nil), len(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressCapture decompresses zstd-compressed JSON and returns it.
|
||||||
|
func decompressCapture(data []byte) ([]byte, error) {
|
||||||
|
dec := zstdDecPool.Get().(*zstd.Decoder)
|
||||||
|
defer zstdDecPool.Put(dec)
|
||||||
|
return dec.DecodeAll(data, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
@@ -40,18 +84,6 @@ type ReqRespCapture struct {
|
|||||||
RespBody []byte `json:"resp_body"`
|
RespBody []byte `json:"resp_body"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the approximate memory usage of this capture in bytes
|
|
||||||
func (c *ReqRespCapture) Size() int {
|
|
||||||
size := len(c.ReqPath) + len(c.ReqBody) + len(c.RespBody)
|
|
||||||
for k, v := range c.ReqHeaders {
|
|
||||||
size += len(k) + len(v)
|
|
||||||
}
|
|
||||||
for k, v := range c.RespHeaders {
|
|
||||||
size += len(k) + len(v)
|
|
||||||
}
|
|
||||||
return size
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenMetricsEvent represents a token metrics event
|
// TokenMetricsEvent represents a token metrics event
|
||||||
type TokenMetricsEvent struct {
|
type TokenMetricsEvent struct {
|
||||||
Metrics TokenMetrics
|
Metrics TokenMetrics
|
||||||
@@ -71,10 +103,10 @@ type metricsMonitor struct {
|
|||||||
|
|
||||||
// capture fields
|
// capture fields
|
||||||
enableCaptures bool
|
enableCaptures bool
|
||||||
captures map[int]ReqRespCapture // map for O(1) lookup by ID
|
captures map[int][]byte // zstd-compressed JSON of ReqRespCapture
|
||||||
captureOrder []int // track insertion order for FIFO eviction
|
captureOrder []int // track insertion order for FIFO eviction
|
||||||
captureSize int // current total size in bytes
|
captureSize int // current total compressed size in bytes
|
||||||
maxCaptureSize int // max bytes for captures
|
maxCaptureSize int // max bytes for captures (uncompressed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
// newMetricsMonitor creates a new metricsMonitor. captureBufferMB is the
|
||||||
@@ -84,7 +116,7 @@ func newMetricsMonitor(logger *LogMonitor, maxMetrics int, captureBufferMB int)
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
maxMetrics: maxMetrics,
|
maxMetrics: maxMetrics,
|
||||||
enableCaptures: captureBufferMB > 0,
|
enableCaptures: captureBufferMB > 0,
|
||||||
captures: make(map[int]ReqRespCapture),
|
captures: make(map[int][]byte),
|
||||||
captureOrder: make([]int, 0),
|
captureOrder: make([]int, 0),
|
||||||
captureSize: 0,
|
captureSize: 0,
|
||||||
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
maxCaptureSize: captureBufferMB * 1024 * 1024,
|
||||||
@@ -108,45 +140,80 @@ func (mp *metricsMonitor) addMetrics(metric TokenMetrics) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addCapture adds a new capture to the buffer with size-based eviction.
|
// addCapture adds a new capture to the buffer with size-based eviction.
|
||||||
// Captures are skipped if enableCaptures is false or if capture exceeds maxCaptureSize.
|
// Captures are skipped if enableCaptures is false or if compressed data exceeds maxCaptureSize.
|
||||||
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
func (mp *metricsMonitor) addCapture(capture ReqRespCapture) {
|
||||||
if !mp.enableCaptures {
|
if !mp.enableCaptures {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mp.mu.Lock()
|
compressed, uncompressedBytes, err := compressCapture(&capture)
|
||||||
defer mp.mu.Unlock()
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to compress capture: %v, skipping", err)
|
||||||
captureSize := capture.Size()
|
|
||||||
if captureSize > mp.maxCaptureSize {
|
|
||||||
mp.logger.Warnf("capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict oldest (FIFO) until room available
|
captureSize := len(compressed)
|
||||||
|
if captureSize > mp.maxCaptureSize {
|
||||||
|
mp.logger.Warnf("compressed capture size %d exceeds max %d, skipping", captureSize, mp.maxCaptureSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
compressionRatio := (1 - float64(captureSize)/float64(uncompressedBytes)) * 100
|
||||||
|
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Evict oldest (FIFO) until room available for the compressed data
|
||||||
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
for mp.captureSize+captureSize > mp.maxCaptureSize && len(mp.captureOrder) > 0 {
|
||||||
oldestID := mp.captureOrder[0]
|
oldestID := mp.captureOrder[0]
|
||||||
mp.captureOrder = mp.captureOrder[1:]
|
mp.captureOrder = mp.captureOrder[1:]
|
||||||
if evicted, exists := mp.captures[oldestID]; exists {
|
if evicted, exists := mp.captures[oldestID]; exists {
|
||||||
mp.captureSize -= evicted.Size()
|
l := len(evicted)
|
||||||
|
mp.captureSize -= l
|
||||||
delete(mp.captures, oldestID)
|
delete(mp.captures, oldestID)
|
||||||
|
mp.logger.Debugf("Capture %d evicted to make space: %d bytes", oldestID, l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mp.captures[capture.ID] = capture
|
mp.captures[capture.ID] = compressed
|
||||||
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
mp.captureOrder = append(mp.captureOrder, capture.ID)
|
||||||
mp.captureSize += captureSize
|
mp.captureSize += captureSize
|
||||||
|
|
||||||
|
mp.logger.Debugf("Capture %d compressed and saved: %d bytes -> %d bytes (%.1f%% compression)", capture.ID, uncompressedBytes, len(compressed), compressionRatio)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCaptureByID returns a capture by its ID, or nil if not found.
|
// getCompressedBytes returns the raw compressed bytes for a capture by ID.
|
||||||
func (mp *metricsMonitor) getCaptureByID(id int) *ReqRespCapture {
|
func (mp *metricsMonitor) getCompressedBytes(id int) ([]byte, bool) {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
if capture, exists := mp.captures[id]; exists {
|
data, exists := mp.captures[id]
|
||||||
return &capture
|
return data, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCaptureByID returns decompressed capture bytes if found and decompress=true.
|
||||||
|
// If decompress=false, returns the raw zstd-compressed bytes.
|
||||||
|
// Returns nil if the capture is not found.
|
||||||
|
func (mp *metricsMonitor) getCaptureByID(id int, decompress bool) []byte {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
data, exists := mp.captures[id]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
if !decompress {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
decompressed, err := decompressCapture(data)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("failed to decompress capture %d: %v", id, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return decompressed
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMetrics returns a copy of the current metrics
|
// getMetrics returns a copy of the current metrics
|
||||||
@@ -290,8 +357,8 @@ func (mp *metricsMonitor) wrapHandler(
|
|||||||
RespHeaders: respHeaders,
|
RespHeaders: respHeaders,
|
||||||
RespBody: body,
|
RespBody: body,
|
||||||
}
|
}
|
||||||
// Only set HasCapture if the capture will actually be stored (not too large)
|
compressed, _, err := compressCapture(capture)
|
||||||
if capture.Size() <= mp.maxCaptureSize {
|
if err == nil && len(compressed) <= mp.maxCaptureSize {
|
||||||
tm.HasCapture = true
|
tm.HasCapture = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -953,28 +954,27 @@ func TestMetricsMonitor_WrapHandler_Compression(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReqRespCapture_Size(t *testing.T) {
|
func TestReqRespCapture_CompressedSize(t *testing.T) {
|
||||||
t.Run("calculates size correctly", func(t *testing.T) {
|
t.Run("compressed size is smaller than uncompressed", func(t *testing.T) {
|
||||||
capture := ReqRespCapture{
|
capture := ReqRespCapture{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
ReqPath: "/v1/chat/completions", // 20 bytes
|
ReqPath: "/v1/chat/completions",
|
||||||
ReqHeaders: map[string]string{
|
ReqBody: []byte(`{"model":"test","prompt":"hello world this is a test request body that is reasonably long"}`),
|
||||||
"Content-Type": "application/json", // 12 + 16 = 28
|
RespBody: []byte(`{"id":"resp-123","object":"chat.completion","created":1234567890,"model":"test-model","choices":[{"index":0,"message":{"role":"assistant","content":"This is a test response body with some meaningful content to compress"}},{"index":1,"message":{"role":"user","content":"Another message here"}}]}`),
|
||||||
},
|
|
||||||
ReqBody: []byte("request body"), // 12 bytes
|
|
||||||
RespHeaders: map[string]string{
|
|
||||||
"X-Test": "value", // 6 + 5 = 11
|
|
||||||
},
|
|
||||||
RespBody: []byte("response body"), // 13 bytes
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expected: 20 + 12 + 13 + 28 + 11 = 84
|
compressed, uncompressed, err := compressCapture(&capture)
|
||||||
assert.Equal(t, 84, capture.Size())
|
assert.NoError(t, err)
|
||||||
|
assert.Greater(t, uncompressed, 0)
|
||||||
|
assert.True(t, len(compressed) < uncompressed, "compressed (%d bytes) should be smaller than uncompressed JSON (%d bytes)", len(compressed), uncompressed)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles empty capture", func(t *testing.T) {
|
t.Run("empty capture produces compressed output", func(t *testing.T) {
|
||||||
capture := ReqRespCapture{}
|
capture := ReqRespCapture{}
|
||||||
assert.Equal(t, 0, capture.Size())
|
compressed, _, err := compressCapture(&capture)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, compressed)
|
||||||
|
assert.True(t, len(compressed) > 0)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -989,7 +989,7 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
// Should not store capture
|
// Should not store capture
|
||||||
assert.Nil(t, mm.getCaptureByID(0))
|
assert.Nil(t, mm.getCaptureByID(0, false))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("adds capture when enabled", func(t *testing.T) {
|
t.Run("adds capture when enabled", func(t *testing.T) {
|
||||||
@@ -1002,41 +1002,55 @@ func TestMetricsMonitor_AddCapture(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
retrieved := mm.getCaptureByID(0)
|
retrieved := mm.getCaptureByID(0, true)
|
||||||
assert.NotNil(t, retrieved)
|
assert.NotNil(t, retrieved)
|
||||||
assert.Equal(t, 0, retrieved.ID)
|
|
||||||
assert.Equal(t, []byte("test request"), retrieved.ReqBody)
|
var decoded ReqRespCapture
|
||||||
assert.Equal(t, []byte("test response"), retrieved.RespBody)
|
err := json.Unmarshal(retrieved, &decoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, decoded.ID)
|
||||||
|
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||||
|
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
t.Run("evicts oldest when exceeding max size", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
mm.maxCaptureSize = 100 // Set small limit for test
|
// Each full ReqRespCapture with 80 bytes random data compresses to ~185 bytes.
|
||||||
|
// 2 captures = ~370 bytes, 3 captures = ~555 bytes. Set limit so only 2 fit.
|
||||||
|
mm.maxCaptureSize = 450
|
||||||
|
|
||||||
// Add captures that will exceed the limit
|
// Use random-looking data that doesn't compress well with zstd
|
||||||
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 40)}
|
rng := rand.New(rand.NewSource(42))
|
||||||
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 40)}
|
capture1 := ReqRespCapture{ID: 0, ReqBody: make([]byte, 80)}
|
||||||
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 40)}
|
rng.Read(capture1.ReqBody)
|
||||||
|
capture2 := ReqRespCapture{ID: 1, ReqBody: make([]byte, 80)}
|
||||||
|
rng.Read(capture2.ReqBody)
|
||||||
|
capture3 := ReqRespCapture{ID: 2, ReqBody: make([]byte, 80)}
|
||||||
|
rng.Read(capture3.ReqBody)
|
||||||
|
|
||||||
mm.addCapture(capture1)
|
mm.addCapture(capture1)
|
||||||
mm.addCapture(capture2)
|
mm.addCapture(capture2)
|
||||||
// Adding capture3 should evict capture1
|
// Adding capture3 should evict capture1
|
||||||
mm.addCapture(capture3)
|
mm.addCapture(capture3)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(0), "capture 0 should be evicted")
|
assert.Nil(t, mm.getCaptureByID(0, true), "capture 0 should be evicted")
|
||||||
assert.NotNil(t, mm.getCaptureByID(1), "capture 1 should exist")
|
retrieved := mm.getCaptureByID(1, true)
|
||||||
assert.NotNil(t, mm.getCaptureByID(2), "capture 2 should exist")
|
assert.NotNil(t, retrieved, "capture 1 should exist")
|
||||||
|
retrieved = mm.getCaptureByID(2, true)
|
||||||
|
assert.NotNil(t, retrieved, "capture 2 should exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("skips capture larger than max size", func(t *testing.T) {
|
t.Run("skips capture larger than max size", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
mm.maxCaptureSize = 100
|
mm.maxCaptureSize = 100
|
||||||
|
|
||||||
// Add a capture larger than max
|
// Use random data that doesn't compress well to create an oversized capture
|
||||||
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 200)}
|
rng := rand.New(rand.NewSource(99))
|
||||||
|
largeCapture := ReqRespCapture{ID: 0, ReqBody: make([]byte, 300)}
|
||||||
|
rng.Read(largeCapture.ReqBody)
|
||||||
mm.addCapture(largeCapture)
|
mm.addCapture(largeCapture)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(0), "oversized capture should not be stored")
|
assert.Nil(t, mm.getCaptureByID(0, false), "oversized capture should not be stored")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1044,21 +1058,44 @@ func TestMetricsMonitor_GetCaptureByID(t *testing.T) {
|
|||||||
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
t.Run("returns nil for non-existent ID", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
assert.Nil(t, mm.getCaptureByID(999))
|
assert.Nil(t, mm.getCaptureByID(999, false))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns capture by ID", func(t *testing.T) {
|
t.Run("returns decompressed capture by ID", func(t *testing.T) {
|
||||||
mm := newMetricsMonitor(testLogger, 10, 5)
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
capture := ReqRespCapture{
|
capture := ReqRespCapture{
|
||||||
ID: 42,
|
ID: 42,
|
||||||
ReqBody: []byte("test"),
|
ReqBody: []byte("test request"),
|
||||||
|
RespBody: []byte("test response"),
|
||||||
}
|
}
|
||||||
mm.addCapture(capture)
|
mm.addCapture(capture)
|
||||||
|
|
||||||
retrieved := mm.getCaptureByID(42)
|
retrieved := mm.getCaptureByID(42, true)
|
||||||
assert.NotNil(t, retrieved)
|
assert.NotNil(t, retrieved)
|
||||||
assert.Equal(t, 42, retrieved.ID)
|
|
||||||
|
var decoded ReqRespCapture
|
||||||
|
err := json.Unmarshal(retrieved, &decoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 42, decoded.ID)
|
||||||
|
assert.Equal(t, []byte("test request"), decoded.ReqBody)
|
||||||
|
assert.Equal(t, []byte("test response"), decoded.RespBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns compressed bytes when decompress=false", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10, 5)
|
||||||
|
|
||||||
|
capture := ReqRespCapture{
|
||||||
|
ID: 42,
|
||||||
|
ReqBody: []byte("test request body"),
|
||||||
|
RespBody: []byte("test response body"),
|
||||||
|
}
|
||||||
|
mm.addCapture(capture)
|
||||||
|
|
||||||
|
compressed := mm.getCaptureByID(42, false)
|
||||||
|
assert.NotNil(t, compressed)
|
||||||
|
// Compressed data should not be valid JSON (it's zstd-compressed)
|
||||||
|
assert.False(t, gjson.ValidBytes(compressed))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1135,9 +1172,13 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
metricID := metrics[0].ID
|
metricID := metrics[0].ID
|
||||||
|
|
||||||
// Check capture was stored with same ID
|
// Check capture was stored with same ID (decompressed)
|
||||||
capture := mm.getCaptureByID(metricID)
|
captureData := mm.getCaptureByID(metricID, true)
|
||||||
assert.NotNil(t, capture)
|
assert.NotNil(t, captureData)
|
||||||
|
|
||||||
|
var capture ReqRespCapture
|
||||||
|
err = json.Unmarshal(captureData, &capture)
|
||||||
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, metricID, capture.ID)
|
assert.Equal(t, metricID, capture.ID)
|
||||||
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
assert.Equal(t, []byte(requestBody), capture.ReqBody)
|
||||||
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
assert.Equal(t, []byte(responseBody), capture.RespBody)
|
||||||
@@ -1173,7 +1214,7 @@ func TestMetricsMonitor_WrapHandler_Capture(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(metrics))
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
|
||||||
// But no capture
|
// But no capture
|
||||||
capture := mm.getCaptureByID(metrics[0].ID)
|
capture := mm.getCaptureByID(metrics[0].ID, false)
|
||||||
assert.Nil(t, capture)
|
assert.Nil(t, capture)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
+62
-4
@@ -77,6 +77,9 @@ type Process struct {
|
|||||||
// used for testing to override the default value
|
// used for testing to override the default value
|
||||||
gracefulStopTimeout time.Duration
|
gracefulStopTimeout time.Duration
|
||||||
|
|
||||||
|
// used for testing to bypass subprocess and reverse proxy
|
||||||
|
testHandler http.Handler
|
||||||
|
|
||||||
// track the number of failed starts
|
// track the number of failed starts
|
||||||
failedStartCount int
|
failedStartCount int
|
||||||
}
|
}
|
||||||
@@ -236,6 +239,49 @@ func (p *Process) forceState(newState ProcessState) {
|
|||||||
// at any time.
|
// at any time.
|
||||||
func (p *Process) start() error {
|
func (p *Process) start() error {
|
||||||
|
|
||||||
|
// test-only fast path: skip subprocess, health check, and TTL goroutine
|
||||||
|
if p.testHandler != nil {
|
||||||
|
if curState, err := p.swapState(StateStopped, StateStarting); err != nil {
|
||||||
|
if err == ErrExpectedStateMismatch {
|
||||||
|
if curState == StateStarting {
|
||||||
|
p.waitStarting.Wait()
|
||||||
|
curState = p.CurrentState()
|
||||||
|
if curState == StateReady {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("process was already starting but wound up in state %v", curState)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("process was in state %v when start() was called", curState)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
defer p.waitStarting.Done()
|
||||||
|
|
||||||
|
// Mimic the real stop path: cancelUpstream transitions
|
||||||
|
// StateStopping -> StateStopped and closes cmdWaitChan,
|
||||||
|
// matching what waitForCmd does for real subprocesses.
|
||||||
|
ch := make(chan struct{})
|
||||||
|
p.cmdMutex.Lock()
|
||||||
|
p.cancelUpstream = func() {
|
||||||
|
if curState := p.CurrentState(); curState == StateStopping {
|
||||||
|
if _, err := p.swapState(StateStopping, StateStopped); err != nil {
|
||||||
|
p.forceState(StateStopped)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.forceState(StateStopped)
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
p.cmdWaitChan = ch
|
||||||
|
p.cmdMutex.Unlock()
|
||||||
|
|
||||||
|
if curState, err := p.swapState(StateStarting, StateReady); err != nil {
|
||||||
|
return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
|
||||||
|
}
|
||||||
|
p.failedStartCount = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if p.config.Proxy == "" {
|
if p.config.Proxy == "" {
|
||||||
return fmt.Errorf("can not start(), upstream proxy missing")
|
return fmt.Errorf("can not start(), upstream proxy missing")
|
||||||
}
|
}
|
||||||
@@ -386,7 +432,10 @@ func (p *Process) start() error {
|
|||||||
|
|
||||||
// Stop will wait for inflight requests to complete before stopping the process.
|
// Stop will wait for inflight requests to complete before stopping the process.
|
||||||
func (p *Process) Stop() {
|
func (p *Process) Stop() {
|
||||||
|
|
||||||
|
// guard to prevent multiple goroutines from stopping
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
if !isValidTransition(p.CurrentState(), StateStopping) {
|
||||||
|
p.proxyLogger.Debugf("<%s> Stop() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,13 +448,17 @@ func (p *Process) Stop() {
|
|||||||
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
|
||||||
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
|
||||||
func (p *Process) StopImmediately() {
|
func (p *Process) StopImmediately() {
|
||||||
if !isValidTransition(p.CurrentState(), StateStopping) {
|
|
||||||
|
// guard to prevent multiple goroutines from stopping the process
|
||||||
|
enterState := p.CurrentState()
|
||||||
|
if !isValidTransition(enterState, StateStopping) {
|
||||||
|
p.proxyLogger.Debugf("<%s> StopImmediate() suppressing invalid transition from %s to StateStopping", p.ID, p.CurrentState())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
|
p.proxyLogger.Debugf("<%s> Stopping process, enter state: %s", p.ID, enterState)
|
||||||
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
|
if curState, err := p.swapState(enterState, StateStopping); err != nil {
|
||||||
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
|
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v, current state: %v", p.ID, enterState, err, curState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,6 +630,11 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
if !srw.waitForCompletion(completionTimeout) {
|
if !srw.waitForCompletion(completionTimeout) {
|
||||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.testHandler != nil {
|
||||||
|
p.testHandler.ServeHTTP(w, r)
|
||||||
|
} else if srw != nil {
|
||||||
p.reverseProxy.ServeHTTP(srw, r)
|
p.reverseProxy.ServeHTTP(srw, r)
|
||||||
} else {
|
} else {
|
||||||
p.reverseProxy.ServeHTTP(w, r)
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
|
|||||||
@@ -24,6 +24,22 @@ type ProcessGroup struct {
|
|||||||
// map of current processes
|
// map of current processes
|
||||||
processes map[string]*Process
|
processes map[string]*Process
|
||||||
lastUsedProcess string
|
lastUsedProcess string
|
||||||
|
|
||||||
|
// inflight tracks fast-path requests (requests for the already-selected
|
||||||
|
// model in a swap group). Fast-path requests Add(1) while holding pg.Lock
|
||||||
|
// and Done() on completion; a concurrent swap request calls inflight.Wait()
|
||||||
|
// under pg.Lock before stopping the current process. Without this tracking,
|
||||||
|
// a fast-path request that has released pg.Lock but has not yet called
|
||||||
|
// Process.inFlightRequests.Add(1) races with Stop()'s Wait() and can be
|
||||||
|
// killed mid-request.
|
||||||
|
inflight sync.WaitGroup
|
||||||
|
|
||||||
|
// testDelayFastPath is a test-only hook that, when non-nil, is invoked in
|
||||||
|
// the fast path after pg.Lock is released but before the request is
|
||||||
|
// dispatched to Process.ProxyRequest. Tests use it to park a fast-path
|
||||||
|
// request at the exact race window to deterministically reproduce the
|
||||||
|
// fast-path vs swap race.
|
||||||
|
testDelayFastPath func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
|
||||||
@@ -64,6 +80,13 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
pg.Lock()
|
pg.Lock()
|
||||||
if pg.lastUsedProcess != modelID {
|
if pg.lastUsedProcess != modelID {
|
||||||
|
|
||||||
|
// Wait for in-flight fast-path requests to drain before stopping
|
||||||
|
// the previous process. Without this, a fast-path request that has
|
||||||
|
// released pg.Lock but has not yet incremented
|
||||||
|
// Process.inFlightRequests races with Stop() and can be killed
|
||||||
|
// mid-request.
|
||||||
|
pg.inflight.Wait()
|
||||||
|
|
||||||
// is there something already running?
|
// is there something already running?
|
||||||
if pg.lastUsedProcess != "" {
|
if pg.lastUsedProcess != "" {
|
||||||
pg.processes[pg.lastUsedProcess].Stop()
|
pg.processes[pg.lastUsedProcess].Stop()
|
||||||
@@ -78,7 +101,16 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
|
|||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fast path: register this request in inflight before releasing
|
||||||
|
// pg.Lock so a concurrent swap will wait for it to complete.
|
||||||
|
pg.inflight.Add(1)
|
||||||
|
defer pg.inflight.Done()
|
||||||
pg.Unlock()
|
pg.Unlock()
|
||||||
|
|
||||||
|
if pg.testDelayFastPath != nil {
|
||||||
|
pg.testDelayFastPath()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pg.processes[modelID].ProxyRequest(writer, request)
|
pg.processes[modelID].ProxyRequest(writer, request)
|
||||||
@@ -123,6 +155,10 @@ func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
|||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
|
|
||||||
|
if strategy != StopImmediately {
|
||||||
|
pg.inflight.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
if len(pg.processes) == 0 {
|
if len(pg.processes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/mostlygeek/llama-swap/proxy/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
|
||||||
@@ -95,6 +98,229 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath verifies that a swap
|
||||||
|
// request cannot stop the current process while a fast-path request (for the
|
||||||
|
// already-selected model) is in flight. Without ProcessGroup-level inflight
|
||||||
|
// tracking, a fast-path request that has released pg.Lock but has not yet
|
||||||
|
// incremented Process.inFlightRequests races with Stop()'s Wait() and the
|
||||||
|
// process is killed mid-request.
|
||||||
|
func TestProcessGroup_ProxyRequestSwapRaceAgainstFastPath(t *testing.T) {
|
||||||
|
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// Bypass real subprocesses so the test is fast and deterministic.
|
||||||
|
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: run a request through model1 via the swap path so that
|
||||||
|
// lastUsedProcess == "model1" and subsequent model1 requests take the
|
||||||
|
// fast path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||||
|
require.Equal(t, StateStopped, pg.processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
// Fast-path hook: signal arrival at the race window, then wait for
|
||||||
|
// release. This parks R2 deterministically at the point where pg.Lock
|
||||||
|
// has been released but Process.inFlightRequests has not yet been
|
||||||
|
// incremented — the exact window the race exploits.
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
pg.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
// R2: fast-path request for model1. Will pause at the test hook.
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Deterministically wait for R2 to reach the race window.
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// R3: swap request for model2. Must wait for R2 to finish before touching
|
||||||
|
// model1, otherwise model1 gets killed mid-request.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model2", w3, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until R3 has acquired pg.Lock and entered the swap critical
|
||||||
|
// section. In the fixed code, R3 then blocks on pg.inflight.Wait() while
|
||||||
|
// still holding the lock, so TryLock keeps failing.
|
||||||
|
for pg.TryLock() {
|
||||||
|
pg.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: give R3 a chance to demonstrate the bug by mutating
|
||||||
|
// state. In the fixed code, R3 is blocked on pg.inflight.Wait() and
|
||||||
|
// nothing changes, so we wait the full window. In the buggy code, R3
|
||||||
|
// will Stop() model1 and start serving via model2 within microseconds —
|
||||||
|
// we exit early once the mutation is observable.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if pg.processes["model1"].CurrentState() != StateReady ||
|
||||||
|
pg.processes["model2"].CurrentState() != StateStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
done := false
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
done = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invariant: R3 must be blocked while R2 is still in flight.
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("swap completed while fast-path request was still in flight — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while a fast-path request is in flight")
|
||||||
|
assert.Equal(t, StateStopped, pg.processes["model2"].CurrentState(),
|
||||||
|
"model2 must not be started until R2 finishes and model1 is swapped out")
|
||||||
|
|
||||||
|
// Release R2 and let both requests finish.
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
assert.Equal(t, http.StatusOK, w3.Code)
|
||||||
|
assert.Contains(t, w3.Body.String(), "model2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessGroup_StopProcessesWaitsForInflight verifies that StopProcesses
|
||||||
|
// (called externally, e.g. from ProxyManager.swapProcessGroup) cannot stop a
|
||||||
|
// process while a fast-path ProxyRequest is in the [pg.Unlock,
|
||||||
|
// Process.inFlightRequests.Add(1)] window. Without pg.inflight.Wait() in
|
||||||
|
// StopProcesses, the external caller bypasses the inflight guard and kills the
|
||||||
|
// process mid-request.
|
||||||
|
func TestProcessGroup_StopProcessesWaitsForInflight(t *testing.T) {
|
||||||
|
cfg := config.AddDefaultGroupToConfig(config.Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]config.ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]config.GroupConfig{
|
||||||
|
"G1": {
|
||||||
|
Swap: true,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
pg := NewProcessGroup("G1", cfg, testLogger, testLogger)
|
||||||
|
defer pg.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
pg.processes["model1"].testHandler = newTestHandler("model1")
|
||||||
|
pg.processes["model2"].testHandler = newTestHandler("model2")
|
||||||
|
|
||||||
|
// Prime: model1 is active so subsequent model1 requests take the fast path.
|
||||||
|
primeReq := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
primeW := httptest.NewRecorder()
|
||||||
|
require.NoError(t, pg.ProxyRequest("model1", primeW, primeReq))
|
||||||
|
require.Equal(t, http.StatusOK, primeW.Code)
|
||||||
|
require.Equal(t, StateReady, pg.processes["model1"].CurrentState())
|
||||||
|
|
||||||
|
// Park a fast-path request at the race window.
|
||||||
|
r2Reached := make(chan struct{})
|
||||||
|
r2Release := make(chan struct{})
|
||||||
|
pg.testDelayFastPath = func() {
|
||||||
|
close(r2Reached)
|
||||||
|
<-r2Release
|
||||||
|
}
|
||||||
|
|
||||||
|
r2Done := make(chan struct{})
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
defer close(r2Done)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
assert.NoError(t, pg.ProxyRequest("model1", w2, req))
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-r2Reached
|
||||||
|
|
||||||
|
// Simulate an external caller (e.g. ProxyManager.swapProcessGroup) stopping
|
||||||
|
// the group while a fast-path request is in flight.
|
||||||
|
r3Done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(r3Done)
|
||||||
|
pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Spin until StopProcesses has acquired pg.Lock.
|
||||||
|
for pg.TryLock() {
|
||||||
|
pg.Unlock()
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounded poll: in the fixed code StopProcesses blocks on pg.inflight.Wait()
|
||||||
|
// and model1 stays Ready. In the buggy code it proceeds immediately and
|
||||||
|
// kills model1.
|
||||||
|
deadline := time.Now().Add(100 * time.Millisecond)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if pg.processes["model1"].CurrentState() != StateReady {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
goto done
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
done:
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-r3Done:
|
||||||
|
t.Fatal("StopProcesses completed while a fast-path request was still in flight — race not prevented")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.Equal(t, StateReady, pg.processes["model1"].CurrentState(),
|
||||||
|
"model1 must stay Ready while a fast-path request is in flight")
|
||||||
|
|
||||||
|
close(r2Release)
|
||||||
|
<-r2Done
|
||||||
|
<-r3Done
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), "model1")
|
||||||
|
}
|
||||||
|
|
||||||
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
||||||
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
||||||
defer pg.StopProcesses(StopWaitForInflightRequest)
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
||||||
|
|||||||
@@ -290,11 +290,26 @@ func (pm *ProxyManager) apiGetCapture(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
capture := pm.metricsMonitor.getCaptureByID(id)
|
data, exists := pm.metricsMonitor.getCompressedBytes(id)
|
||||||
if capture == nil {
|
if !exists {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "capture not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, capture)
|
c.Header("Vary", "Accept-Encoding")
|
||||||
|
|
||||||
|
// ¯\_(ツ)_/¯ quality weights are too fancy for us anyway
|
||||||
|
hasZstd := strings.Contains(c.GetHeader("Accept-Encoding"), "zstd")
|
||||||
|
|
||||||
|
if hasZstd {
|
||||||
|
c.Header("Content-Encoding", "zstd")
|
||||||
|
c.Data(http.StatusOK, "application/json", data)
|
||||||
|
} else {
|
||||||
|
decompressed, err := decompressCapture(data)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decompress capture"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/json", decompressed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+322
-339
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,97 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { inFlightRequests, metrics } from "../stores/api";
|
||||||
|
import { persistentStore } from "../stores/persistent";
|
||||||
|
import { calculateHistogramData } from "../lib/histogram";
|
||||||
|
import TokenHistogram from "./TokenHistogram.svelte";
|
||||||
|
|
||||||
|
const nf = new Intl.NumberFormat();
|
||||||
|
const histogramCollapsed = persistentStore<boolean>("activity-histogram-collapsed", false);
|
||||||
|
|
||||||
|
let stats = $derived.by(() => {
|
||||||
|
const totalRequests = $metrics.length;
|
||||||
|
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
||||||
|
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
||||||
|
const totalCacheTokens = $metrics.reduce((sum, m) => sum + m.cache_tokens, 0);
|
||||||
|
|
||||||
|
const promptPerSecond = $metrics.filter((m) => m.prompt_per_second > 0).map((m) => m.prompt_per_second);
|
||||||
|
|
||||||
|
const tokensPerSecond = $metrics.filter((m) => m.tokens_per_second > 0).map((m) => m.tokens_per_second);
|
||||||
|
|
||||||
|
const promptHistogramData =
|
||||||
|
promptPerSecond.length > 0 ? calculateHistogramData(promptPerSecond) : null;
|
||||||
|
|
||||||
|
const genHistogramData =
|
||||||
|
tokensPerSecond.length > 0 ? calculateHistogramData(tokensPerSecond) : null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
totalRequests,
|
||||||
|
totalInputTokens,
|
||||||
|
totalOutputTokens,
|
||||||
|
totalCacheTokens,
|
||||||
|
inFlightRequests: $inFlightRequests,
|
||||||
|
promptHistogramData,
|
||||||
|
genHistogramData,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="card relative p-3">
|
||||||
|
<button
|
||||||
|
class="absolute top-2 right-2 w-6 h-6 flex items-center justify-center rounded-full border border-gray-300 dark:border-gray-600 text-gray-400 dark:text-gray-500 hover:text-gray-600 dark:hover:text-gray-300 hover:border-gray-400 dark:hover:border-gray-400 transition-colors"
|
||||||
|
onclick={() => ($histogramCollapsed = !$histogramCollapsed)}
|
||||||
|
title={$histogramCollapsed ? "Show histograms" : "Hide histograms"}
|
||||||
|
>
|
||||||
|
{#if $histogramCollapsed}
|
||||||
|
<svg class="w-3.5 h-3.5" viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M4.5 6l3.5 4 3.5-4H4.5z" />
|
||||||
|
</svg>
|
||||||
|
{:else}
|
||||||
|
<svg class="w-3 h-3" viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M3.5 3.5l9 9M12.5 3.5l-9 9" stroke="currentColor" stroke-width="2" stroke-linecap="round" fill="none" />
|
||||||
|
</svg>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
{#if !$histogramCollapsed}
|
||||||
|
<div class="flex flex-col sm:flex-row gap-6 mb-3">
|
||||||
|
<div class="w-full sm:w-1/2 min-w-0">
|
||||||
|
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Prompt Processing</div>
|
||||||
|
{#if stats.promptHistogramData}
|
||||||
|
<TokenHistogram
|
||||||
|
data={stats.promptHistogramData}
|
||||||
|
unit="prompt tokens/sec"
|
||||||
|
colorClass="text-amber-500 dark:text-amber-400"
|
||||||
|
/>
|
||||||
|
{:else}
|
||||||
|
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No prompt speed data yet</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
<div class="w-full sm:w-1/2 min-w-0">
|
||||||
|
<div class="text-sm font-medium text-gray-500 dark:text-gray-400 mb-1">Token Generation</div>
|
||||||
|
{#if stats.genHistogramData}
|
||||||
|
<TokenHistogram data={stats.genHistogramData} unit="tokens/sec" />
|
||||||
|
{:else}
|
||||||
|
<div class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">No generation speed data yet</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
<div class="grid grid-cols-4 gap-x-6 gap-y-1 text-sm">
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Requests</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Cached</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Processed</div>
|
||||||
|
<div class="text-xs uppercase tracking-wider text-gray-500 dark:text-gray-400">Generated</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalRequests)}</span> completed,
|
||||||
|
<span class="font-semibold">{nf.format(stats.inFlightRequests)}</span> waiting
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalCacheTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalInputTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
<div class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
<span class="font-semibold">{nf.format(stats.totalOutputTokens)}</span> tokens
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
@@ -106,6 +106,7 @@
|
|||||||
const delta = parsed.choices?.[0]?.delta;
|
const delta = parsed.choices?.[0]?.delta;
|
||||||
if (delta?.content) result.content += delta.content;
|
if (delta?.content) result.content += delta.content;
|
||||||
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
if (delta?.reasoning_content) result.reasoning += delta.reasoning_content;
|
||||||
|
if (delta?.reasoning) result.reasoning += delta.reasoning;
|
||||||
} catch {
|
} catch {
|
||||||
// skip unparseable lines
|
// skip unparseable lines
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@
|
|||||||
<a
|
<a
|
||||||
href="/"
|
href="/"
|
||||||
use:link
|
use:link
|
||||||
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
class="p-1 whitespace-nowrap {isActive('/', $currentRoute) ? 'font-semibold underline underline-offset-4' : ''} {$playgroundActivity ? 'activity-link' : 'text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100'}"
|
||||||
>
|
>
|
||||||
Playground
|
Playground
|
||||||
</a>
|
</a>
|
||||||
@@ -59,6 +59,8 @@
|
|||||||
use:link
|
use:link
|
||||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
class:font-semibold={isActive("/models", $currentRoute)}
|
class:font-semibold={isActive("/models", $currentRoute)}
|
||||||
|
class:underline={isActive("/models", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/models", $currentRoute)}
|
||||||
>
|
>
|
||||||
Models
|
Models
|
||||||
</a>
|
</a>
|
||||||
@@ -67,6 +69,8 @@
|
|||||||
use:link
|
use:link
|
||||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
class:font-semibold={isActive("/activity", $currentRoute)}
|
class:font-semibold={isActive("/activity", $currentRoute)}
|
||||||
|
class:underline={isActive("/activity", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/activity", $currentRoute)}
|
||||||
>
|
>
|
||||||
Activity
|
Activity
|
||||||
</a>
|
</a>
|
||||||
@@ -75,6 +79,8 @@
|
|||||||
use:link
|
use:link
|
||||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
class:font-semibold={isActive("/logs", $currentRoute)}
|
class:font-semibold={isActive("/logs", $currentRoute)}
|
||||||
|
class:underline={isActive("/logs", $currentRoute)}
|
||||||
|
class:underline-offset-4={isActive("/logs", $currentRoute)}
|
||||||
>
|
>
|
||||||
Logs
|
Logs
|
||||||
</a>
|
</a>
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
<script lang="ts">
|
|
||||||
import { inFlightRequests, metrics } from "../stores/api";
|
|
||||||
import TokenHistogram from "./TokenHistogram.svelte";
|
|
||||||
|
|
||||||
interface HistogramData {
|
|
||||||
bins: number[];
|
|
||||||
min: number;
|
|
||||||
max: number;
|
|
||||||
binSize: number;
|
|
||||||
p99: number;
|
|
||||||
p95: number;
|
|
||||||
p50: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
let stats = $derived.by(() => {
|
|
||||||
const totalRequests = $metrics.length;
|
|
||||||
if (totalRequests === 0) {
|
|
||||||
return {
|
|
||||||
totalRequests: 0,
|
|
||||||
totalInputTokens: 0,
|
|
||||||
totalOutputTokens: 0,
|
|
||||||
inFlightRequests: $inFlightRequests,
|
|
||||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
|
||||||
histogramData: null,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const totalInputTokens = $metrics.reduce((sum, m) => sum + m.input_tokens, 0);
|
|
||||||
const totalOutputTokens = $metrics.reduce((sum, m) => sum + m.output_tokens, 0);
|
|
||||||
|
|
||||||
// Calculate token statistics using output_tokens and duration_ms
|
|
||||||
const validMetrics = $metrics.filter((m) => m.duration_ms > 0 && m.output_tokens > 0);
|
|
||||||
if (validMetrics.length === 0) {
|
|
||||||
return {
|
|
||||||
totalRequests,
|
|
||||||
totalInputTokens,
|
|
||||||
totalOutputTokens,
|
|
||||||
inFlightRequests: $inFlightRequests,
|
|
||||||
tokenStats: { p99: "0", p95: "0", p50: "0" },
|
|
||||||
histogramData: null,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate tokens/second for each valid metric
|
|
||||||
const tokensPerSecond = validMetrics.map((m) => m.output_tokens / (m.duration_ms / 1000));
|
|
||||||
|
|
||||||
// Sort for percentile calculation
|
|
||||||
const sortedTokensPerSecond = [...tokensPerSecond].sort((a, b) => a - b);
|
|
||||||
|
|
||||||
const p99 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.99)];
|
|
||||||
const p95 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.95)];
|
|
||||||
const p50 = sortedTokensPerSecond[Math.floor(sortedTokensPerSecond.length * 0.5)];
|
|
||||||
|
|
||||||
// Create histogram data
|
|
||||||
const min = Math.min(...tokensPerSecond);
|
|
||||||
const max = Math.max(...tokensPerSecond);
|
|
||||||
const binCount = Math.min(30, Math.max(10, Math.floor(tokensPerSecond.length / 5)));
|
|
||||||
const binSize = (max - min) / binCount;
|
|
||||||
|
|
||||||
const bins = Array(binCount).fill(0);
|
|
||||||
tokensPerSecond.forEach((value) => {
|
|
||||||
const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
|
|
||||||
bins[binIndex]++;
|
|
||||||
});
|
|
||||||
|
|
||||||
const histogramData: HistogramData = {
|
|
||||||
bins,
|
|
||||||
min,
|
|
||||||
max,
|
|
||||||
binSize,
|
|
||||||
p99,
|
|
||||||
p95,
|
|
||||||
p50,
|
|
||||||
};
|
|
||||||
|
|
||||||
return {
|
|
||||||
totalRequests,
|
|
||||||
totalInputTokens,
|
|
||||||
totalOutputTokens,
|
|
||||||
inFlightRequests: $inFlightRequests,
|
|
||||||
tokenStats: {
|
|
||||||
p99: p99.toFixed(2),
|
|
||||||
p95: p95.toFixed(2),
|
|
||||||
p50: p50.toFixed(2),
|
|
||||||
},
|
|
||||||
histogramData,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const nf = new Intl.NumberFormat();
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<div class="card">
|
|
||||||
<div class="rounded-lg overflow-hidden border border-card-border-inner">
|
|
||||||
<table class="min-w-full divide-y divide-card-border-inner">
|
|
||||||
<thead class="bg-secondary">
|
|
||||||
<tr>
|
|
||||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain">Requests</th>
|
|
||||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
|
||||||
Processed
|
|
||||||
</th>
|
|
||||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
|
||||||
Generated
|
|
||||||
</th>
|
|
||||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-txtmain border-l border-card-border-inner">
|
|
||||||
Token Stats (tokens/sec)
|
|
||||||
</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
|
|
||||||
<tbody class="bg-surface divide-y divide-card-border-inner">
|
|
||||||
<tr class="hover:bg-secondary">
|
|
||||||
<td class="px-4 py-4 text-sm font-semibold text-gray-900 dark:text-white">
|
|
||||||
<div class="flex flex-col gap-1">
|
|
||||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Completed: {nf.format(stats.totalRequests)}</span>
|
|
||||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">Waiting: {nf.format(stats.inFlightRequests)}</span>
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
|
|
||||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
|
||||||
<div class="flex items-center gap-2">
|
|
||||||
<span class="text-sm font-medium">{nf.format(stats.totalInputTokens)}</span>
|
|
||||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
|
|
||||||
<td class="px-4 py-4 text-sm text-gray-700 dark:text-gray-300 border-l border-gray-200 dark:border-white/10">
|
|
||||||
<div class="flex items-center gap-2">
|
|
||||||
<span class="text-sm font-medium">{nf.format(stats.totalOutputTokens)}</span>
|
|
||||||
<span class="text-xs text-gray-500 dark:text-gray-400">tokens</span>
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
|
|
||||||
<td class="px-4 py-4 border-l border-gray-200 dark:border-white/10">
|
|
||||||
<div class="space-y-3">
|
|
||||||
<div class="grid grid-cols-3 gap-2 items-center">
|
|
||||||
<div class="text-center">
|
|
||||||
<div class="text-xs text-gray-500 dark:text-gray-400">P50</div>
|
|
||||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
|
||||||
{stats.tokenStats.p50}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="text-center">
|
|
||||||
<div class="text-xs text-gray-500 dark:text-gray-400">P95</div>
|
|
||||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
|
||||||
{stats.tokenStats.p95}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="text-center">
|
|
||||||
<div class="text-xs text-gray-500 dark:text-gray-400">P99</div>
|
|
||||||
<div class="mt-1 inline-block rounded-full bg-gray-100 dark:bg-white/5 px-3 py-1 text-sm font-semibold text-gray-800 dark:text-white">
|
|
||||||
{stats.tokenStats.p99}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{#if stats.histogramData}
|
|
||||||
<TokenHistogram data={stats.histogramData} />
|
|
||||||
{/if}
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
@@ -1,23 +1,19 @@
|
|||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
interface HistogramData {
|
import type { HistogramData } from "../lib/types";
|
||||||
bins: number[];
|
|
||||||
min: number;
|
|
||||||
max: number;
|
|
||||||
binSize: number;
|
|
||||||
p99: number;
|
|
||||||
p95: number;
|
|
||||||
p50: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface Props {
|
let {
|
||||||
|
data,
|
||||||
|
unit = "tokens/sec",
|
||||||
|
colorClass = "text-blue-500 dark:text-blue-400",
|
||||||
|
}: {
|
||||||
data: HistogramData;
|
data: HistogramData;
|
||||||
}
|
unit?: string;
|
||||||
|
colorClass?: string;
|
||||||
|
} = $props();
|
||||||
|
|
||||||
let { data }: Props = $props();
|
const height = 250;
|
||||||
|
const padding = { top: 30, right: 20, bottom: 40, left: 75 };
|
||||||
const height = 120;
|
const viewBoxWidth = 1200;
|
||||||
const padding = { top: 10, right: 15, bottom: 25, left: 45 };
|
|
||||||
const viewBoxWidth = 600;
|
|
||||||
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
const chartWidth = viewBoxWidth - padding.left - padding.right;
|
||||||
const chartHeight = height - padding.top - padding.bottom;
|
const chartHeight = height - padding.top - padding.bottom;
|
||||||
|
|
||||||
@@ -43,6 +39,24 @@
|
|||||||
opacity="0.3"
|
opacity="0.3"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- Y-axis ticks and labels -->
|
||||||
|
{#each [0, 0.5, 1] as fraction}
|
||||||
|
{@const tickCount = Math.round(maxCount * fraction)}
|
||||||
|
{@const tickY = height - padding.bottom - fraction * chartHeight}
|
||||||
|
<line
|
||||||
|
x1={padding.left - 8}
|
||||||
|
y1={tickY}
|
||||||
|
x2={padding.left}
|
||||||
|
y2={tickY}
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1"
|
||||||
|
opacity="0.4"
|
||||||
|
/>
|
||||||
|
<text x={padding.left - 10} y={tickY + 10} font-size="26" fill="currentColor" opacity="0.8" text-anchor="end">
|
||||||
|
{tickCount}
|
||||||
|
</text>
|
||||||
|
{/each}
|
||||||
|
|
||||||
<!-- X-axis -->
|
<!-- X-axis -->
|
||||||
<line
|
<line
|
||||||
x1={padding.left}
|
x1={padding.left}
|
||||||
@@ -69,9 +83,9 @@
|
|||||||
height={barHeight}
|
height={barHeight}
|
||||||
fill="currentColor"
|
fill="currentColor"
|
||||||
opacity="0.6"
|
opacity="0.6"
|
||||||
class="text-blue-500 dark:text-blue-400 hover:opacity-90 transition-opacity cursor-pointer"
|
class="{colorClass} hover:opacity-90 transition-opacity cursor-pointer"
|
||||||
/>
|
/>
|
||||||
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} tokens/sec\nCount: ${count}`}</title>
|
<title>{`${binStart.toFixed(1)} - ${binEnd.toFixed(1)} ${unit}\nCount: ${count}`}</title>
|
||||||
</g>
|
</g>
|
||||||
{/each}
|
{/each}
|
||||||
|
|
||||||
@@ -113,17 +127,19 @@
|
|||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- X-axis labels -->
|
<!-- X-axis labels -->
|
||||||
<text x={padding.left} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="start">
|
<text x={padding.left} y={height - 8} font-size="26" fill="currentColor" opacity="0.8" text-anchor="start">
|
||||||
{data.min.toFixed(1)}
|
{data.min.toFixed(1)}
|
||||||
</text>
|
</text>
|
||||||
|
|
||||||
<text x={viewBoxWidth - padding.right} y={height - 5} font-size="10" fill="currentColor" opacity="0.6" text-anchor="end">
|
<text
|
||||||
|
x={viewBoxWidth - padding.right}
|
||||||
|
y={height - 8}
|
||||||
|
font-size="26"
|
||||||
|
fill="currentColor"
|
||||||
|
opacity="0.8"
|
||||||
|
text-anchor="end"
|
||||||
|
>
|
||||||
{data.max.toFixed(1)}
|
{data.max.toFixed(1)}
|
||||||
</text>
|
</text>
|
||||||
|
|
||||||
<!-- X-axis label -->
|
|
||||||
<text x={padding.left + chartWidth / 2} y={height - 2} font-size="10" fill="currentColor" opacity="0.6" text-anchor="middle">
|
|
||||||
Tokens/Second Distribution
|
|
||||||
</text>
|
|
||||||
</svg>
|
</svg>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ function parseSSELine(line: string): StreamChunk | null {
|
|||||||
const parsed = JSON.parse(data);
|
const parsed = JSON.parse(data);
|
||||||
const delta = parsed.choices?.[0]?.delta;
|
const delta = parsed.choices?.[0]?.delta;
|
||||||
const content = delta?.content || "";
|
const content = delta?.content || "";
|
||||||
const reasoning_content = delta?.reasoning_content || "";
|
const reasoning_content = delta?.reasoning_content || delta?.reasoning || "";
|
||||||
|
|
||||||
if (content || reasoning_content) {
|
if (content || reasoning_content) {
|
||||||
return { content, reasoning_content, done: false };
|
return { content, reasoning_content, done: false };
|
||||||
|
|||||||
@@ -0,0 +1,167 @@
|
|||||||
|
import { describe, it, expect } from "vitest";
|
||||||
|
import { calculateHistogramData } from "./histogram";
|
||||||
|
|
||||||
|
describe("calculateHistogramData", () => {
|
||||||
|
describe("edge cases", () => {
|
||||||
|
it("returns null for empty input", () => {
|
||||||
|
expect(calculateHistogramData([])).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles single value", () => {
|
||||||
|
const result = calculateHistogramData([42]);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.bins).toEqual([1]);
|
||||||
|
expect(result!.min).toBe(42);
|
||||||
|
expect(result!.max).toBe(42);
|
||||||
|
expect(result!.binSize).toBe(0);
|
||||||
|
expect(result!.p50).toBe(42);
|
||||||
|
expect(result!.p95).toBe(42);
|
||||||
|
expect(result!.p99).toBe(42);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles all identical values", () => {
|
||||||
|
const result = calculateHistogramData([10, 10, 10, 10, 10]);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.bins).toEqual([5]);
|
||||||
|
expect(result!.min).toBe(10);
|
||||||
|
expect(result!.max).toBe(10);
|
||||||
|
expect(result!.binSize).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles two distinct values", () => {
|
||||||
|
const result = calculateHistogramData([10, 20]);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.min).toBe(10);
|
||||||
|
expect(result!.max).toBe(20);
|
||||||
|
expect(result!.p50).toBe(15);
|
||||||
|
const binSum = result!.bins.reduce((s, b) => s + b, 0);
|
||||||
|
expect(binSum).toBe(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("bin distribution", () => {
|
||||||
|
it("bins sum to total number of values", () => {
|
||||||
|
const values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
const binSum = result!.bins.reduce((s, b) => s + b, 0);
|
||||||
|
expect(binSum).toBe(values.length);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("distributes uniform values across bins", () => {
|
||||||
|
const values = Array.from({ length: 100 }, (_, i) => i);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.bins.length).toBe(8);
|
||||||
|
const binSum = result!.bins.reduce((s, b) => s + b, 0);
|
||||||
|
expect(binSum).toBe(100);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("places values in correct bins", () => {
|
||||||
|
const values = [1, 1, 1, 5, 5, 9, 9, 9];
|
||||||
|
const result = calculateHistogramData(values, { minBins: 3, maxBins: 3 });
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.bins.length).toBe(3);
|
||||||
|
expect(result!.bins.reduce((s, b) => s + b, 0)).toBe(8);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles skewed distribution", () => {
|
||||||
|
const values = [1, 1, 1, 1, 1, 100];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
const binSum = result!.bins.reduce((s, b) => s + b, 0);
|
||||||
|
expect(binSum).toBe(6);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("percentiles", () => {
|
||||||
|
it("calculates correct p50 for even-length array", () => {
|
||||||
|
const values = [1, 2, 3, 4];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.p50).toBe(2.5);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calculates correct p50 for odd-length array", () => {
|
||||||
|
const values = [1, 2, 3, 4, 5];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.p50).toBe(3);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calculates p99 with interpolation", () => {
|
||||||
|
const values = Array.from({ length: 100 }, (_, i) => i + 1);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.p99).toBeCloseTo(99.01);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calculates p95 with interpolation", () => {
|
||||||
|
const values = Array.from({ length: 100 }, (_, i) => i + 1);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.p95).toBeCloseTo(95.05);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("percentiles are monotonically increasing", () => {
|
||||||
|
const values = Array.from({ length: 200 }, () => Math.random() * 100);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result!.p50).toBeLessThanOrEqual(result!.p95);
|
||||||
|
expect(result!.p95).toBeLessThanOrEqual(result!.p99);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("bin count adaptation", () => {
|
||||||
|
it("uses minimum bins for small datasets", () => {
|
||||||
|
// n=8: sturges=4, clamped up to minBins=5
|
||||||
|
const values = Array.from({ length: 8 }, (_, i) => i);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result!.bins.length).toBe(5);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("scales bins with dataset size", () => {
|
||||||
|
// n=100: sturges=8
|
||||||
|
const values = Array.from({ length: 100 }, (_, i) => i);
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result!.bins.length).toBe(8);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("caps bins at maximum", () => {
|
||||||
|
// n=1000: sturges=11, clamped down to maxBins=10
|
||||||
|
const values = Array.from({ length: 1000 }, (_, i) => i);
|
||||||
|
const result = calculateHistogramData(values, { minBins: 5, maxBins: 10 });
|
||||||
|
expect(result!.bins.length).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("respects custom options", () => {
|
||||||
|
// n=100: sturges=8, within [minBins=5, maxBins=10]
|
||||||
|
const values = Array.from({ length: 100 }, (_, i) => i);
|
||||||
|
const result = calculateHistogramData(values, { minBins: 5, maxBins: 10 });
|
||||||
|
expect(result!.bins.length).toBe(8);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("min and max", () => {
|
||||||
|
it("correctly identifies min and max", () => {
|
||||||
|
const values = [5, 3, 8, 1, 9, 2];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result!.min).toBe(1);
|
||||||
|
expect(result!.max).toBe(9);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles negative values", () => {
|
||||||
|
const values = [-10, -5, 0, 5, 10];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result!.min).toBe(-10);
|
||||||
|
expect(result!.max).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles floating point values", () => {
|
||||||
|
const values = [1.5, 2.7, 3.14, 0.5, 4.99];
|
||||||
|
const result = calculateHistogramData(values);
|
||||||
|
expect(result!.min).toBe(0.5);
|
||||||
|
expect(result!.max).toBe(4.99);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import type { HistogramData } from "./types";
|
||||||
|
|
||||||
|
export interface HistogramOptions {
|
||||||
|
minBins?: number;
|
||||||
|
maxBins?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_OPTIONS: HistogramOptions = {
|
||||||
|
minBins: 5,
|
||||||
|
maxBins: 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
function percentile(sorted: number[], p: number): number {
|
||||||
|
if (sorted.length === 0) return 0;
|
||||||
|
if (sorted.length === 1) return sorted[0];
|
||||||
|
|
||||||
|
const rank = (p / 100) * (sorted.length - 1);
|
||||||
|
const lower = Math.floor(rank);
|
||||||
|
const upper = Math.ceil(rank);
|
||||||
|
const fraction = rank - lower;
|
||||||
|
|
||||||
|
return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Builds an equal-width histogram plus p50/p95/p99 for a list of numbers.
 *
 * The bin count follows Sturges' formula (ceil(log2(n)) + 1), clamped to the
 * [minBins, maxBins] range from `options`, and is always normalised to an
 * integer of at least 1 so a bad option value cannot crash the allocation.
 *
 * @param values - Samples to summarise. The array is not mutated. Non-finite
 *                 entries (NaN, ±Infinity) are ignored because they cannot be
 *                 assigned to a bin.
 * @param options - Optional bounds on the number of bins (defaults: 5 and 20).
 * @returns The histogram data, or null when there are no usable samples.
 *          When every sample is identical, a single bin holding all samples is
 *          returned with binSize 0.
 */
export function calculateHistogramData(
  values: number[],
  options: HistogramOptions = DEFAULT_OPTIONS,
): HistogramData | null {
  // Copy-and-filter so the caller's array is untouched and NaN/Infinity cannot
  // poison min/max or produce a NaN bin index further down.
  const sorted = values.filter((v) => Number.isFinite(v)).sort((a, b) => a - b);
  if (sorted.length === 0) return null;

  const min = sorted[0];
  const max = sorted[sorted.length - 1];

  const p50 = percentile(sorted, 50);
  const p95 = percentile(sorted, 95);
  const p99 = percentile(sorted, 99);

  // Zero range: a bin width cannot be computed, so report one bin with everything.
  if (min === max) {
    return {
      bins: [sorted.length],
      min,
      max,
      binSize: 0,
      p50,
      p95,
      p99,
    };
  }

  const { minBins = 5, maxBins = 20 } = options;
  const sturges = Math.ceil(Math.log2(sorted.length)) + 1;
  const clamped = Math.min(maxBins, Math.max(minBins, sturges));
  // Guard against NaN, fractional, zero or negative option values:
  // `new Array(n)` throws for non-integers/negatives and zero bins would
  // yield a bin index of -1.
  const binCount = Number.isFinite(clamped) ? Math.max(1, Math.floor(clamped)) : sturges;
  const binSize = (max - min) / binCount;

  const bins: number[] = new Array(binCount).fill(0);
  for (const value of sorted) {
    // The maximum value lands exactly on the upper edge; fold it into the last bin.
    const binIndex = Math.min(Math.floor((value - min) / binSize), binCount - 1);
    bins[binIndex]++;
  }

  return {
    bins,
    min,
    max,
    binSize,
    p50,
    p95,
    p99,
  };
}
|
||||||
@@ -48,6 +48,16 @@ export interface APIEventEnvelope {
|
|||||||
data: string;
|
data: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Histogram summary of a numeric sample, as produced by calculateHistogramData.
 */
export interface HistogramData {
|
||||||
|
/** Number of samples that fell into each bin, lowest bin first. */
bins: number[];
|
||||||
|
/** Smallest sample value. */
min: number;
|
||||||
|
/** Largest sample value. */
max: number;
|
||||||
|
/** Width of every bin, (max - min) / bins.length; 0 when all samples are identical. */
binSize: number;
|
||||||
|
/** 99th percentile of the samples. */
p99: number;
|
||||||
|
/** 95th percentile of the samples. */
p95: number;
|
||||||
|
/** 50th percentile (median) of the samples. */
p50: number;
|
||||||
|
}
|
||||||
|
|
||||||
export interface VersionInfo {
|
export interface VersionInfo {
|
||||||
build_date: string;
|
build_date: string;
|
||||||
commit: string;
|
commit: string;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { metrics, getCapture } from "../stores/api";
|
import { metrics, getCapture } from "../stores/api";
|
||||||
|
import ActivityStats from "../components/ActivityStats.svelte";
|
||||||
import Tooltip from "../components/Tooltip.svelte";
|
import Tooltip from "../components/Tooltip.svelte";
|
||||||
import CaptureDialog from "../components/CaptureDialog.svelte";
|
import CaptureDialog from "../components/CaptureDialog.svelte";
|
||||||
import type { ReqRespCapture } from "../lib/types";
|
import type { ReqRespCapture } from "../lib/types";
|
||||||
@@ -62,34 +63,38 @@
|
|||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div class="p-2">
|
<div class="p-2">
|
||||||
<h1 class="text-2xl font-bold">Activity</h1>
|
<div class="mt-4 mb-4">
|
||||||
|
<ActivityStats />
|
||||||
|
</div>
|
||||||
|
|
||||||
{#if $metrics.length === 0}
|
<div class="card overflow-auto">
|
||||||
<div class="text-center py-8">
|
<table class="min-w-full divide-y">
|
||||||
<p class="text-gray-600">No metrics data available</p>
|
<thead class="border-gray-200 dark:border-white/10">
|
||||||
</div>
|
<tr class="text-left text-xs uppercase tracking-wider">
|
||||||
{:else}
|
<th class="px-6 py-3">ID</th>
|
||||||
<div class="card overflow-auto">
|
<th class="px-6 py-3">Time</th>
|
||||||
<table class="min-w-full divide-y">
|
<th class="px-6 py-3">Model</th>
|
||||||
<thead class="border-gray-200 dark:border-white/10">
|
<th class="px-6 py-3">
|
||||||
<tr class="text-left text-xs uppercase tracking-wider">
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
<th class="px-6 py-3">ID</th>
|
</th>
|
||||||
<th class="px-6 py-3">Time</th>
|
<th class="px-6 py-3">
|
||||||
<th class="px-6 py-3">Model</th>
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
<th class="px-6 py-3">
|
</th>
|
||||||
Cached <Tooltip content="prompt tokens from cache" />
|
<th class="px-6 py-3">Generated</th>
|
||||||
</th>
|
<th class="px-6 py-3">Prompt Processing</th>
|
||||||
<th class="px-6 py-3">
|
<th class="px-6 py-3">Generation Speed</th>
|
||||||
Prompt <Tooltip content="new prompt tokens processed" />
|
<th class="px-6 py-3">Duration</th>
|
||||||
</th>
|
<th class="px-6 py-3">Capture</th>
|
||||||
<th class="px-6 py-3">Generated</th>
|
</tr>
|
||||||
<th class="px-6 py-3">Prompt Processing</th>
|
</thead>
|
||||||
<th class="px-6 py-3">Generation Speed</th>
|
<tbody class="divide-y">
|
||||||
<th class="px-6 py-3">Duration</th>
|
{#if sortedMetrics.length === 0}
|
||||||
<th class="px-6 py-3">Capture</th>
|
<tr>
|
||||||
|
<td colspan="10" class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
No activity recorded
|
||||||
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
{:else}
|
||||||
<tbody class="divide-y">
|
|
||||||
{#each sortedMetrics as metric (metric.id)}
|
{#each sortedMetrics as metric (metric.id)}
|
||||||
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
|
||||||
<td class="px-4 py-4">{metric.id + 1}</td>
|
<td class="px-4 py-4">{metric.id + 1}</td>
|
||||||
@@ -116,10 +121,10 @@
|
|||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
</tbody>
|
{/if}
|
||||||
</table>
|
</tbody>
|
||||||
</div>
|
</table>
|
||||||
{/if}
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
|
<CaptureDialog capture={selectedCapture} open={dialogOpen} onclose={closeDialog} />
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import { isNarrow } from "../stores/theme";
|
import { isNarrow } from "../stores/theme";
|
||||||
import { upstreamLogs } from "../stores/api";
|
import { upstreamLogs } from "../stores/api";
|
||||||
import ModelsPanel from "../components/ModelsPanel.svelte";
|
import ModelsPanel from "../components/ModelsPanel.svelte";
|
||||||
import StatsPanel from "../components/StatsPanel.svelte";
|
|
||||||
import LogPanel from "../components/LogPanel.svelte";
|
import LogPanel from "../components/LogPanel.svelte";
|
||||||
import ResizablePanels from "../components/ResizablePanels.svelte";
|
import ResizablePanels from "../components/ResizablePanels.svelte";
|
||||||
|
|
||||||
@@ -14,13 +13,6 @@
|
|||||||
<ModelsPanel />
|
<ModelsPanel />
|
||||||
{/snippet}
|
{/snippet}
|
||||||
{#snippet rightPanel()}
|
{#snippet rightPanel()}
|
||||||
<div class="flex flex-col h-full space-y-4">
|
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
|
||||||
{#if direction === "horizontal"}
|
|
||||||
<StatsPanel />
|
|
||||||
{/if}
|
|
||||||
<div class="flex-1 min-h-0">
|
|
||||||
<LogPanel id="modelsupstream" title="Upstream Logs" logData={$upstreamLogs} />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/snippet}
|
{/snippet}
|
||||||
</ResizablePanels>
|
</ResizablePanels>
|
||||||
|
|||||||
Reference in New Issue
Block a user