Compare commits

...

13 Commits

Author SHA1 Message Date
steve 0292c90ca1 ci: copy ui-svelte/.npmrc before npm ci in fork-cuda build
Build CUDA image (fork) / build (push) Successful in 12m49s
npm ci ran without .npmrc (legacy-peer-deps=true), failing on the
tailwind/vite peer dependency conflict. Copy .npmrc with the manifest.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-28 12:56:21 -04:00
steve 617c7dc6b9 ci: add Gitea workflow to build fork CUDA image
Build CUDA image (fork) / build (push) Failing after 2m23s
Add a Gitea Actions workflow and multi-stage Containerfile that build
this fork's llama-swap (serial scheduler + embedded Svelte UI) from
source and layer it on a pinned llama.cpp CUDA server base, then push to
the Gitea container registry as v230-cuda-b9821.

- docker/fork-cuda.Containerfile: node UI -> go build -> cuda runtime,
  runs as root to match the upstream non-suffixed image
- .gitea/workflows/build-cuda-image.yml: workflow_dispatch (version +
  llama.cpp build inputs) and push-on-build-files; logs in with
  REGISTRY_USER/REGISTRY_PASSWORD

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-28 12:48:48 -04:00
steve 542b79dacf internal/router/scheduler: add serial scheduler, default on this fork
Validate JSON Schema / validate-schema (push) Successful in 9m53s
Linux CI / run-tests (push) Failing after 15m57s
Windows CI / run-tests (push) Has been cancelled
Add a strict one-model-at-a-time scheduler. Requests run in exact
arrival order; at most one runs at a time; switching to a different
model evicts every other running model first so a single model occupies
memory at a time. Unlike fifo it never reorders or batches same-model
requests, and it ignores group/matrix co-residency entirely, making the
single-model guarantee a property of the scheduler rather than the config.

- new Serial scheduler implementing the Scheduler interface
- register "serial" in scheduler.New; default routing.scheduler.use to
  "serial" at config load (fifo still selectable for upstream behavior)
- update config schema, example config, and config defaults tests

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-28 12:17:32 -04:00
Benson Wong 0a25b3bd31 AGENTS.md: small tweaks 2026-06-25 20:31:48 -07:00
Benson Wong 32bc781326 internal/config,watcher: add -config-dir (#873)
Over time the llama-swap configuration file can get really long and
challenging to work with. The -config-dir flag is used for a directory
of configuration YAML fragments.

These fragments are merged together and into a full configuration and
tested for validity. All previous configuration functionality remains
unchanged.
2026-06-24 20:48:51 -07:00
Benson Wong 316ad63f76 config,server: add upstream.ignorePaths (#869)
Add upstream.ignorePaths config to prevent model swaps for static-asset
requests made through the /upstream/<model>/<path> passthrough endpoint.

- add UpstreamConfig with compiled *regexp.Regexp slice; invalid regex
returns an error at load time
- apply a default pattern matching common static-asset suffixes
(.js/.json/.css/.png/.gif/.jpg/.jpeg/.ico/.txt) when unset
- in handleUpstream, return 409 Conflict when a path matches and the
local model is not already loaded; peer and already-loaded models fall
through to normal dispatch
- update config-schema.json and config.example.yaml

Updates discussion: #868
2026-06-21 13:49:53 -07:00
g2mt e37077a963 feat: hide performance menu item if disabled (#832)
Hide the Performance UI item of the navigation bar if its disabled.
2026-06-21 13:38:29 -07:00
Benson Wong eff9b60434 server: capture failed (non-200) LLM requests (#862)
Store a request/response capture for non-200 responses so failed
requests can be inspected in the activity log's Capture dialog, matching
the existing behavior for successful requests.

- extract storeCapture/decodeResponseBody helpers to share capture logic
between the success and non-200 paths
- record non-200 bodies (decompressed) so error details are viewable
- the activity UI already gates the View button on has_capture, so it
now appears for failed requests with no UI changes
- add tests for capturing failed requests and the disabled-captures case

closes #766
2026-06-20 11:50:35 -07:00
Wojciech 9bcddad91b internal/server,ui: add new Acitivty page column - Drafted (#859)
Add draft metrics to activity log
2026-06-18 20:55:02 -07:00
Benson Wong a15e47922c proxy: meter /upstream requests via metrics middleware (#858)
Wrap /upstream/{upstreamPath...} in the metrics middleware so activity
log entries are recorded for model-dispatched endpoints accessed through
the upstream passthrough.

- Move findModelInPath to shared.FindModelInPath and reuse it in
handleUpstream, the log monitor lookup, and FetchContext.
- Extend FetchContext to resolve the model from /upstream/<model>/...
paths without consuming the request body.
- Add isMetricsRecordPath to limit recording to the model-dispatched
endpoints that produce token usage/timings.
- Add tests for upstream metrics recording and FetchContext upstream
path resolution.

Fixes #855
2026-06-17 17:38:52 -07:00
George 0ab214d1c8 perf: add vendor-agnostic GPU monitoring for Windows (experimental) (#779)
Add GPU monitoring support for AMD and Intel GPUs on Windows using
D3DKMT (DirectX) and PDH performance counters.

- Add PDH-based GPU utilization via \GPU Engine(*)\Utilization
Percentage counter, summing all engine types per adapter (3D, Compute,
Copy, Video).
- Add D3DKMT bindings for adapter enumeration, memory segments, and
adapter perf data.
- Use PDH as primary utilization source (works on all vendors), with
D3DKMT RunningTime as fallback for systems without PDH counters.
- Prefer nvidia-smi when available, fall back to D3DKMT + PDH for
AMD/Intel.
- Backend priority: nvidia-smi -> D3DKMT + PDH -> ErrNoGpuTool.

Verified on AMD 7900XTX GPU with llama.cpp Vulkan & ROCm backend: GPU
utilization correctly shows ~99% during inference, ~0-2% when idle.

---

LLM disclosure: GLM 5.1 & Kimi K2.6 have been used extensively during
exploration and coding to the point that the LLM's wrote over 3/4 of the
code, and I have done additional verification myself.
As such, it should be considered experimental.
Additional verification is needed.

I have tested it on my 7900XTX system with Windows 11, and it works
correctly, but as I only have this one rig, I cannot verify it
everywhere.
2026-06-16 21:49:09 -07:00
Benson Wong d07b063ab6 internal/server,shared: support request metadata (#850)
- add support for http handlers in the request chain to append metadata
to the request
- metrics middleware will include metadata in the activity log 
- update Activity UI to support metadata, drag sort columns
- update Activity UI capture dialog to use more screen space

Updates #834
2026-06-16 21:44:55 -07:00
Benson Wong 826210dac9 .coderabbit.yaml: disable unit_tests 2026-06-16 10:10:17 -07:00
48 changed files with 4779 additions and 910 deletions
+2
View File
@@ -15,6 +15,8 @@ reviews:
auto_review: auto_review:
enabled: false enabled: false
drafts: false drafts: false
unit_tests:
enabled: false
chat: chat:
auto_reply: true auto_reply: true
issue_enrichment: issue_enrichment:
+76
View File
@@ -0,0 +1,76 @@
name: Build CUDA image (fork)
# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
#
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).
on:
workflow_dispatch:
inputs:
llama_swap_version:
description: "llama-swap version label (image tag prefix)"
required: false
default: "v230"
llamacpp_build:
description: "llama.cpp CUDA server build (base image tag suffix)"
required: false
default: "b9821"
# Building the build definition itself kicks off a fresh image.
push:
branches: [main]
paths:
- ".gitea/workflows/build-cuda-image.yml"
- "docker/fork-cuda.Containerfile"
env:
REGISTRY: gitea.stevedudenhoeffer.com
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Compute image metadata
id: meta
run: |
LS_VER="${{ inputs.llama_swap_version || 'v230' }}"
LCPP="${{ inputs.llamacpp_build || 'b9821' }}"
{
echo "image=${REGISTRY}/${{ github.repository }}"
echo "tag=${LS_VER}-cuda-${LCPP}"
echo "base_tag=server-cuda-${LCPP}"
echo "ls_version=${LS_VER}"
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
} >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Gitea registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ secrets.REGISTRY_USER }}
password: ${{ secrets.REGISTRY_PASSWORD }}
- name: Build and push
uses: docker/build-push-action@v6
with:
context: .
file: docker/fork-cuda.Containerfile
push: true
provenance: false
build-args: |
BASE_TAG=${{ steps.meta.outputs.base_tag }}
LS_VERSION=${{ steps.meta.outputs.ls_version }}
GIT_HASH=${{ github.sha }}
BUILD_DATE=${{ steps.meta.outputs.build_date }}
tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
- name: Summary
run: |
echo "Pushed ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}" >> "$GITHUB_STEP_SUMMARY"
+4 -6
View File
@@ -5,16 +5,14 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
## Tech stack ## Tech stack
- golang - golang
- typescript, vite and svelt5 for UI (located in ui/) - typescript, vite and svelte 5 for UI (located in ui-svelte/)
## Workflow Tasks ## Workflow Tasks
- when summarizing changes only include details that require further action - when summarizing changes only include details that require further action
- just say "Done." when there is no further action
- use the github CLI `gh` to create pull requests and work with github
- Rules for creating pull requests: - Rules for creating pull requests:
- keep them short and focused on changes. - keep them short and focused on changes
- never include a test plan - skip the test plan
- write the summary using the same style rules as commit message - write the summary using the same style rules as commit message
## Testing ## Testing
@@ -30,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
### Commit message example format: ### Commit message example format:
``` ```
proxy: add new feature internal/server: add new feature
Add new feature that implements functionality X and Y. Add new feature that implements functionality X and Y.
+21 -2
View File
@@ -572,6 +572,24 @@
"default": {}, "default": {},
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
}, },
"upstream": {
"type": "object",
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
"properties": {
"ignorePaths": {
"type": "array",
"items": {
"type": "string"
},
"default": [
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
],
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
}
},
"additionalProperties": false,
"default": {}
},
"routing": { "routing": {
"type": "object", "type": "object",
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.", "description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
@@ -583,10 +601,11 @@
"use": { "use": {
"type": "string", "type": "string",
"enum": [ "enum": [
"serial",
"fifo" "fifo"
], ],
"default": "fifo", "default": "serial",
"description": "Scheduler to use. Only 'fifo' is currently supported." "description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
}, },
"settings": { "settings": {
"type": "object", "type": "object",
+25 -3
View File
@@ -134,6 +134,18 @@ apiKeys:
- "${env.API_KEY_1}" - "${env.API_KEY_1}"
- "${env.API_KEY_2}" - "${env.API_KEY_2}"
# upstream: controls behaviour of the /upstream passthrough endpoint
# - optional, default: empty dictionary
# - recommended to only use in special use cases. Leaving it as the
# default will typically be the best experience
upstream:
# ignorePaths: list of RE2 compatible regular expressions
# - default: (see below)
# - any request to a path matching any of the regular expressions
# will be ignored and not trigger a swap
ignorePaths:
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
# models: a dictionary of model configurations # models: a dictionary of model configurations
# - required # - required
# - each key is the model's ID, used in API requests # - each key is the model's ID, used in API requests
@@ -544,11 +556,21 @@ routing:
# expands to: [L] # expands to: [L]
full: "L" full: "L"
# scheduler: how queued requests are ordered. # scheduler: how queued requests are ordered and run.
# The default and only valid scheduler is "fifo" # - optional, default on this fork: "serial"
# - valid values:
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
# order; only one request runs at a time; switching to a different model
# evicts every other running model first so a single model occupies memory
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
# settings below (priority) do not apply.
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
# swaps and a model serves up to its concurrencyLimit in parallel; models
# in non-exclusive groups can run concurrently. Requests may be reordered.
scheduler: scheduler:
use: fifo use: serial
settings: settings:
# fifo settings only apply when use: fifo
fifo: fifo:
# priority: a dictionary of model ID -> priority # priority: a dictionary of model ID -> priority
# - optional, default: empty dictionary # - optional, default: empty dictionary
+74
View File
@@ -0,0 +1,74 @@
# Build a CUDA llama-swap image FROM THIS FORK's source (includes the serial
# scheduler) and layer it on a pinned llama.cpp CUDA server base. Produces e.g.:
# gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
#
# BASE_TAG selects the llama.cpp CUDA runtime + llama-server build, e.g.
# "server-cuda-b9821". The llama-swap binary (with the embedded Svelte UI) is
# compiled from the repo at build time, so no GitHub release is required.
#
# Build context is the repo root:
# docker build -f docker/fork-cuda.Containerfile \
# --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda-b9821
# ---- Stage 1: build the Svelte UI (embedded into the binary) ----
FROM node:22-bookworm-slim AS ui
WORKDIR /src/ui-svelte
# Install deps first for layer caching. .npmrc carries legacy-peer-deps=true,
# which the project relies on (tailwind/vite peer ranges), so copy it before
# npm ci or the strict resolver fails with ERESOLVE.
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
RUN npm ci
COPY ui-svelte/ ./
# `npm run build` is `vite build --emptyOutDir`; vite.config.ts writes to
# ../internal/server/ui_dist, which //go:embed picks up in the next stage.
RUN mkdir -p /src/internal/server && npm run build
# ---- Stage 2: build the llama-swap binary with the embedded UI ----
FROM golang:1.26-bookworm AS build
WORKDIR /src
# Cache modules independently of source churn.
COPY go.mod go.sum ./
RUN go mod download
COPY . .
# Overlay the freshly built UI so //go:embed ui_dist ships the real assets
# instead of the committed placeholder.
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
ARG LS_VERSION=v230
ARG GIT_HASH=unknown
ARG BUILD_DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build \
-ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
-o /out/llama-swap .
# ---- Stage 3: runtime image on the pinned llama.cpp CUDA base ----
FROM ${BASE_IMAGE}:${BASE_TAG}
# Run as root by default to match the upstream `vNNN-cuda-bNNNN` (non-suffixed)
# image that ragnaros pulls today: it needs root to reach the mounted docker
# socket for container-backed models (sd-server). Override UID/GID at build time
# for a non-root variant.
ARG UID=0
ARG GID=0
ARG USER_HOME=/root
ENV HOME=$USER_HOME
RUN set -eux; \
if [ "$UID" -ne 0 ]; then \
if [ "$GID" -ne 0 ]; then groupadd --system --gid "$GID" app; fi; \
useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
fi; \
mkdir --parents "$HOME" /app; \
chown --recursive "$UID:$GID" "$HOME" /app
COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml
USER $UID:$GID
WORKDIR /app
ENV PATH="/app:${PATH}"
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
+57
View File
@@ -0,0 +1,57 @@
package config
import (
"fmt"
"runtime"
"strings"
"github.com/billziss-gh/golib/shlex"
)
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
+3 -661
View File
@@ -2,16 +2,9 @@ package config
import ( import (
"fmt" "fmt"
"io"
"net/url"
"os" "os"
"regexp"
"runtime"
"sort" "sort"
"strings"
"time"
"github.com/billziss-gh/golib/shlex"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -85,12 +78,6 @@ type GroupConfig struct {
Members []string `yaml:"members"` Members []string `yaml:"members"`
} }
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
)
// set default values for GroupConfig // set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig type rawGroupConfig GroupConfig
@@ -163,6 +150,9 @@ type Config struct {
// support remote peers, see issue #433, #296 // support remote peers, see issue #433, #296
Peers PeerDictionaryConfig `yaml:"peers"` Peers PeerDictionaryConfig `yaml:"peers"`
// upstream controls behaviour of the /upstream passthrough endpoint
Upstream UpstreamConfig `yaml:"upstream"`
} }
// RoutingConfig is the canonical, normalized routing/scheduling configuration. // RoutingConfig is the canonical, normalized routing/scheduling configuration.
@@ -221,424 +211,6 @@ func LoadConfig(path string) (Config, error) {
return LoadConfigFromReader(file) return LoadConfigFromReader(file)
} }
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
yamlStr := string(data)
// Phase 1: Substitute all ${env.VAR} macros at string level
// This is safe because env values are simple strings without YAML formatting
yamlStr, err = substituteEnvMacros(yamlStr)
if err != nil {
return Config{}, err
}
// Unmarshal into full Config with defaults
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
GlobalTTL: 0,
}
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15
}
// Apply defaults for performance config when section is missing
if config.Performance.Every == 0 {
config.Performance.Every = 5 * time.Second
}
if err = config.Performance.Validate(); err != nil {
return Config{}, fmt.Errorf("performance: %w", err)
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
if config.GlobalTTL < 0 {
return Config{}, fmt.Errorf("globalTTL must be >= 0")
}
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
// Validate global macros
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
// Get and sort all model IDs for consistent port assignment
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds)
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
// Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// set model TTL to globalTTL it is the default value
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
modelConfig.UnloadAfter = config.GlobalTTL
}
if modelConfig.UnloadAfter < 0 {
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
}
// Validate model macros
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (override globals with same name)
for _, entry := range modelConfig.Macros {
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// Substitute remaining macros in model fields (LIFO order)
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
// Substitute macros in SetParamsByID keys and values
if len(modelConfig.Filters.SetParamsByID) > 0 {
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
for key, paramMap := range modelConfig.Filters.SetParamsByID {
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
}
newParamMap, ok := newValAny.(map[string]any)
if !ok {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
}
newSetParamsByID[newKey] = newParamMap
}
modelConfig.Filters.SetParamsByID = newSetParamsByID
}
// Substitute in metadata (type-preserving)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
}
// Handle PORT macro - only allocate if cmd uses it
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort {
if !cmdHasPort && proxyHasPort {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
nextPort++
}
// Validate no unknown macros remain
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
"name": modelConfig.Name,
"description": modelConfig.Description,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // replaced at runtime
}
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
if len(modelConfig.Metadata) > 0 {
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
return Config{}, err
}
}
if err = modelConfig.Capabilities.Validate(); err != nil {
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
}
// Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
}
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
return Config{}, err
}
}
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
for key := range modelConfig.Filters.SetParamsByID {
if key == modelId {
continue
}
if _, exists := config.Models[key]; exists {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
}
if existingModel, exists := config.aliases[key]; exists {
if existingModel != modelId {
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
}
continue // already registered as explicit alias for this model
}
config.aliases[key] = modelId
modelConfig.Aliases = append(modelConfig.Aliases, key)
}
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
}
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig
}
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
// the new `routing.router` block are mutually exclusive: a config may use
// either style, never both.
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
rtr := config.Routing.Router
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
if hasTopLevel && hasRouting {
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
}
if !hasTopLevel {
// Both groups and matrix may be defined under routing.router.settings;
// routing.router.use selects which one is active, so there is no conflict.
rs := config.Routing.Router.Settings
switch config.Routing.Router.Use {
case "matrix":
if rs.Matrix == nil {
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
}
config.Matrix = rs.Matrix
case "group", "":
config.Groups = rs.Groups
default:
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
}
}
// groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
}
if config.Matrix != nil {
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
if err != nil {
return Config{}, fmt.Errorf("matrix: %w", err)
}
config.Matrix.ExpandedSets = expandedSets
} else {
config = AddDefaultGroupToConfig(config)
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
}
// Build the canonical Config.Routing from the effective result. Both legacy
// and new-style configs converge here. The Matrix pointer is shared so
// ExpandedSets stays in one place.
if config.Matrix != nil {
config.Routing.Router.Use = "matrix"
} else {
config.Routing.Router.Use = "group"
}
config.Routing.Router.Settings.Matrix = config.Matrix
config.Routing.Router.Settings.Groups = config.Groups
if config.Routing.Scheduler.Use == "" {
config.Routing.Scheduler.Use = "fifo"
}
if config.Routing.Scheduler.Use != "fifo" {
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
}
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
if _, found := config.RealModelName(modelID); !found {
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
}
}
// Clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
// Validate API keys (env macros already substituted at string level)
for i, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
}
config.RequiredAPIKeys[i] = apikey
}
// Process peers with global macro substitution
for peerName, peerConfig := range config.Peers {
// Substitute global macros (LIFO order)
for i := len(config.Macros) - 1; i >= 0; i-- {
entry := config.Macros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in setParams (type-preserving)
if len(peerConfig.Filters.SetParams) > 0 {
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
}
peerConfig.Filters.SetParams = result.(map[string]any)
}
}
// Validate no unknown macros remain
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
}
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
}
if len(peerConfig.Filters.SetParams) > 0 {
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
return Config{}, err
}
}
config.Peers[peerName] = peerConfig
}
return config, nil
}
// rewrites the yaml to include a default group with any orphaned models // rewrites the yaml to include a default group with any orphaned models
func AddDefaultGroupToConfig(config Config) Config { func AddDefaultGroupToConfig(config Config) Config {
@@ -683,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
return config return config
} }
func SanitizeCommand(cmdStr string) ([]string, error) {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
// Handle trailing backslashes by replacing with space
if strings.HasSuffix(trimmed, "\\") {
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
} else {
cleanedLines = append(cleanedLines, line)
}
}
// put it back together
cmdStr = strings.Join(cleanedLines, "\n")
// Split the command into arguments
var args []string
if runtime.GOOS == "windows" {
args = shlex.Windows.Split(cmdStr)
} else {
args = shlex.Posix.Split(cmdStr)
}
// Ensure the command is not empty
if len(args) == 0 {
return nil, fmt.Errorf("empty command")
}
return args, nil
}
func StripComments(cmdStr string) string {
var cleanedLines []string
for _, line := range strings.Split(cmdStr, "\n") {
trimmed := strings.TrimSpace(line)
// Skip comment lines
if strings.HasPrefix(trimmed, "#") {
continue
}
cleanedLines = append(cleanedLines, line)
}
return strings.Join(cleanedLines, "\n")
}
// validateMacro validates macro name and value constraints
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
// Validate that value is a scalar type
switch v := value.(type) {
case string:
// Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
func validateNestedForUnknownMacros(value any, context string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
}
// Check for unsubstituted env macros
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
for _, match := range envMatches {
varName := match[1]
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
case []any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
default:
// Scalar types don't contain macros
return nil
}
}
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
// Returns error if any referenced env var is not set or contains invalid characters.
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
// (which strips comments) and only checking the comment-free version for macros.
func substituteEnvMacros(s string) (string, error) {
// Unmarshal and remarshal to strip YAML comments
var raw any
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
// If YAML is invalid, fall back to scanning the original string
// so the user gets the env var error rather than a confusing YAML parse error
return substituteEnvMacrosInString(s, s)
}
clean, err := yaml.Marshal(raw)
if err != nil {
return substituteEnvMacrosInString(s, s)
}
return substituteEnvMacrosInString(s, string(clean))
}
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
// them in target. This separation allows scanning comment-free YAML while
// substituting in the original string.
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
result := target
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
for _, match := range matches {
fullMatch := match[0] // ${env.VAR_NAME}
varName := match[1] // VAR_NAME
value, exists := os.LookupEnv(varName)
if !exists {
return "", fmt.Errorf("environment variable '%s' is not set", varName)
}
// Sanitize the value for safe YAML substitution
value, err := sanitizeEnvValueForYAML(value, varName)
if err != nil {
return "", err
}
result = strings.ReplaceAll(result, fullMatch, value)
}
return result, nil
}
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
// for compatibility with double-quoted YAML strings.
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
// Reject values that would break YAML structure regardless of quoting context
if strings.ContainsAny(value, "\n\r\x00") {
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
}
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
// In double-quoted contexts, they are interpreted correctly.
value = strings.ReplaceAll(value, `\`, `\\`)
value = strings.ReplaceAll(value, `"`, `\"`)
return value, nil
}
+4 -1
View File
@@ -266,6 +266,9 @@ groups:
"mthree": "model3", "mthree": "model3",
}, },
Groups: expectedGroups, Groups: expectedGroups,
Upstream: UpstreamConfig{
IgnorePaths: DefaultUpstreamIgnorePaths(),
},
Routing: RoutingConfig{ Routing: RoutingConfig{
Router: RouterConfig{ Router: RouterConfig{
Use: "group", Use: "group",
@@ -274,7 +277,7 @@ groups:
}, },
}, },
Scheduler: SchedulerConfig{ Scheduler: SchedulerConfig{
Use: "fifo", Use: "serial",
}, },
}, },
} }
+11 -6
View File
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
{ {
name: "blank spaces only", name: "blank spaces only",
content: `apiKeys: [" "]`, content: `apiKeys: [" "]`,
expectedErr: "api key cannot contain spaces: ` `", expectedErr: "apiKeys[0]: api key cannot contain spaces",
}, },
{ {
name: "contains leading space", name: "contains leading space",
content: `apiKeys: [" key123"]`, content: `apiKeys: [" key123"]`,
expectedErr: "api key cannot contain spaces: ` key123`", expectedErr: "apiKeys[0]: api key cannot contain spaces",
}, },
{ {
name: "contains trailing space", name: "contains trailing space",
content: `apiKeys: ["key123 "]`, content: `apiKeys: ["key123 "]`,
expectedErr: "api key cannot contain spaces: `key123 `", expectedErr: "apiKeys[0]: api key cannot contain spaces",
}, },
{ {
name: "contains middle space", name: "contains middle space",
content: `apiKeys: ["key 123"]`, content: `apiKeys: ["key 123"]`,
expectedErr: "api key cannot contain spaces: `key 123`", expectedErr: "apiKeys[0]: api key cannot contain spaces",
},
{
name: "space in second key reports correct index",
content: `apiKeys: ["valid-key", "bad key"]`,
expectedErr: "apiKeys[1]: api key cannot contain spaces",
}, },
{ {
name: "empty in list with valid keys", name: "empty in list with valid keys",
@@ -1567,7 +1572,7 @@ groups:
assert.Equal(t, "group", cfg.Routing.Router.Use) assert.Equal(t, "group", cfg.Routing.Router.Use)
// default group injected for orphaned models (none here) still leaves g1 // default group injected for orphaned models (none here) still leaves g1
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1") assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use) assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
} }
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) { func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
@@ -1626,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels)) cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "group", cfg.Routing.Router.Use) assert.Equal(t, "group", cfg.Routing.Router.Use)
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use) assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
} }
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) { func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
+4 -1
View File
@@ -255,6 +255,9 @@ groups:
"mthree": "model3", "mthree": "model3",
}, },
Groups: expectedGroups, Groups: expectedGroups,
Upstream: UpstreamConfig{
IgnorePaths: DefaultUpstreamIgnorePaths(),
},
Routing: RoutingConfig{ Routing: RoutingConfig{
Router: RouterConfig{ Router: RouterConfig{
Use: "group", Use: "group",
@@ -263,7 +266,7 @@ groups:
}, },
}, },
Scheduler: SchedulerConfig{ Scheduler: SchedulerConfig{
Use: "fifo", Use: "serial",
}, },
}, },
} }
+441
View File
@@ -0,0 +1,441 @@
package config
import (
"fmt"
"io"
"net/url"
"sort"
"strings"
"time"
"gopkg.in/yaml.v3"
)
func LoadConfigFromReader(r io.Reader) (Config, error) {
data, err := io.ReadAll(r)
if err != nil {
return Config{}, err
}
yamlStr := string(data)
// Phase 1: Substitute all ${env.VAR} macros at string level
// This is safe because env values are simple strings without YAML formatting
yamlStr, err = substituteEnvMacros(yamlStr)
if err != nil {
return Config{}, err
}
// Unmarshal into full Config with defaults
config := Config{
HealthCheckTimeout: 120,
StartPort: 5800,
LogLevel: "info",
LogTimeFormat: "",
LogToStdout: LogToStdoutProxy,
MetricsMaxInMemory: 1000,
CaptureBuffer: 5,
GlobalTTL: 0,
}
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
return Config{}, err
}
if config.HealthCheckTimeout < 15 {
config.HealthCheckTimeout = 15
}
// Apply defaults for performance config when section is missing
if config.Performance.Every == 0 {
config.Performance.Every = 5 * time.Second
}
if err = config.Performance.Validate(); err != nil {
return Config{}, fmt.Errorf("performance: %w", err)
}
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
if config.GlobalTTL < 0 {
return Config{}, fmt.Errorf("globalTTL must be >= 0")
}
// Apply default for upstream.ignorePaths when not specified. The default
// matches common static-asset suffixes so they do not trigger a swap.
if len(config.Upstream.IgnorePaths) == 0 {
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
}
switch config.LogToStdout {
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
default:
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
}
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
if _, found := config.aliases[alias]; found {
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
}
config.aliases[alias] = modelName
}
}
// Validate global macros
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
// Get and sort all model IDs for consistent port assignment
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
modelIds = append(modelIds, modelId)
}
sort.Strings(modelIds)
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
// Strip comments from command fields
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// set model TTL to globalTTL it is the default value
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
modelConfig.UnloadAfter = config.GlobalTTL
}
if modelConfig.UnloadAfter < 0 {
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
}
// Validate model macros
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (override globals with same name)
for _, entry := range modelConfig.Macros {
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
// Substitute remaining macros in model fields (LIFO order)
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
// Substitute macros in SetParamsByID keys and values
if len(modelConfig.Filters.SetParamsByID) > 0 {
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
for key, paramMap := range modelConfig.Filters.SetParamsByID {
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
}
newParamMap, ok := newValAny.(map[string]any)
if !ok {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
}
newSetParamsByID[newKey] = newParamMap
}
modelConfig.Filters.SetParamsByID = newSetParamsByID
}
// Substitute in metadata (type-preserving)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
}
// Handle PORT macro - only allocate if cmd uses it
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort {
if !cmdHasPort && proxyHasPort {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
if len(modelConfig.Metadata) > 0 {
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
nextPort++
}
// Validate no unknown macros remain
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
"name": modelConfig.Name,
"description": modelConfig.Description,
}
for fieldName, fieldValue := range fieldMap {
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
continue // replaced at runtime
}
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
if len(modelConfig.Metadata) > 0 {
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
return Config{}, err
}
}
if err = modelConfig.Capabilities.Validate(); err != nil {
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
}
// Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
}
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
return Config{}, err
}
}
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
for key := range modelConfig.Filters.SetParamsByID {
if key == modelId {
continue
}
if _, exists := config.Models[key]; exists {
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
}
if existingModel, exists := config.aliases[key]; exists {
if existingModel != modelId {
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
}
continue // already registered as explicit alias for this model
}
config.aliases[key] = modelId
modelConfig.Aliases = append(modelConfig.Aliases, key)
}
if _, err := url.Parse(modelConfig.Proxy); err != nil {
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
}
if modelConfig.SendLoadingState == nil {
v := config.SendLoadingState
modelConfig.SendLoadingState = &v
}
config.Models[modelId] = modelConfig
}
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
// the new `routing.router` block are mutually exclusive: a config may use
// either style, never both.
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
rtr := config.Routing.Router
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
if hasTopLevel && hasRouting {
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
}
if !hasTopLevel {
// Both groups and matrix may be defined under routing.router.settings;
// routing.router.use selects which one is active, so there is no conflict.
rs := config.Routing.Router.Settings
switch config.Routing.Router.Use {
case "matrix":
if rs.Matrix == nil {
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
}
config.Matrix = rs.Matrix
case "group", "":
config.Groups = rs.Groups
default:
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
}
}
// groups XOR matrix
if config.Matrix != nil && len(config.Groups) > 0 {
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
}
if config.Matrix != nil {
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
if err != nil {
return Config{}, fmt.Errorf("matrix: %w", err)
}
config.Matrix.ExpandedSets = expandedSets
} else {
config = AddDefaultGroupToConfig(config)
// Validate group members
memberUsage := make(map[string]string)
for groupID, groupConfig := range config.Groups {
prevSet := make(map[string]bool)
for _, member := range groupConfig.Members {
if _, found := prevSet[member]; found {
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
}
prevSet[member] = true
if existingGroup, exists := memberUsage[member]; exists {
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
}
memberUsage[member] = groupID
}
}
}
// Build the canonical Config.Routing from the effective result. Both legacy
// and new-style configs converge here. The Matrix pointer is shared so
// ExpandedSets stays in one place.
if config.Matrix != nil {
config.Routing.Router.Use = "matrix"
} else {
config.Routing.Router.Use = "group"
}
config.Routing.Router.Settings.Matrix = config.Matrix
config.Routing.Router.Settings.Groups = config.Groups
// This fork defaults to the "serial" scheduler: one model loaded at a time,
// requests served in strict arrival order. Set use: fifo for the upstream
// throughput-oriented behavior that batches same-model requests.
if config.Routing.Scheduler.Use == "" {
config.Routing.Scheduler.Use = "serial"
}
switch config.Routing.Scheduler.Use {
case "fifo", "serial":
default:
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
}
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
if _, found := config.RealModelName(modelID); !found {
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
}
}
// Clean up hooks preload
if len(config.Hooks.OnStartup.Preload) > 0 {
var toPreload []string
for _, modelID := range config.Hooks.OnStartup.Preload {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
if real, found := config.RealModelName(modelID); found {
toPreload = append(toPreload, real)
}
}
config.Hooks.OnStartup.Preload = toPreload
}
// Validate API keys (env macros already substituted at string level)
for i, apikey := range config.RequiredAPIKeys {
if apikey == "" {
return Config{}, fmt.Errorf("empty api key found in apiKeys")
}
if strings.Contains(apikey, " ") {
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
}
config.RequiredAPIKeys[i] = apikey
}
// Process peers with global macro substitution
for peerName, peerConfig := range config.Peers {
// Substitute global macros (LIFO order)
for i := len(config.Macros) - 1; i >= 0; i-- {
entry := config.Macros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in setParams (type-preserving)
if len(peerConfig.Filters.SetParams) > 0 {
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
}
peerConfig.Filters.SetParams = result.(map[string]any)
}
}
// Validate no unknown macros remain
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
}
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
}
if len(peerConfig.Filters.SetParams) > 0 {
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
return Config{}, err
}
}
config.Peers[peerName] = peerConfig
}
return config, nil
}
+198
View File
@@ -0,0 +1,198 @@
package config
import (
"fmt"
"os"
"regexp"
"strings"
"gopkg.in/yaml.v3"
)
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
)
// validateMacro validates macro name and value constraints
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
// Validate that value is a scalar type
switch v := value.(type) {
case string:
// Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
func validateNestedForUnknownMacros(value any, context string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
}
// Check for unsubstituted env macros
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
for _, match := range envMatches {
varName := match[1]
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
case []any:
for _, val := range v {
if err := validateNestedForUnknownMacros(val, context); err != nil {
return err
}
}
return nil
default:
// Scalar types don't contain macros
return nil
}
}
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
// Returns error if any referenced env var is not set or contains invalid characters.
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
// (which strips comments) and only checking the comment-free version for macros.
func substituteEnvMacros(s string) (string, error) {
// Unmarshal and remarshal to strip YAML comments
var raw any
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
// If YAML is invalid, fall back to scanning the original string
// so the user gets the env var error rather than a confusing YAML parse error
return substituteEnvMacrosInString(s, s)
}
clean, err := yaml.Marshal(raw)
if err != nil {
return substituteEnvMacrosInString(s, s)
}
return substituteEnvMacrosInString(s, string(clean))
}
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
// them in target. This separation allows scanning comment-free YAML while
// substituting in the original string.
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
result := target
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
for _, match := range matches {
fullMatch := match[0] // ${env.VAR_NAME}
varName := match[1] // VAR_NAME
value, exists := os.LookupEnv(varName)
if !exists {
return "", fmt.Errorf("environment variable '%s' is not set", varName)
}
// Sanitize the value for safe YAML substitution
value, err := sanitizeEnvValueForYAML(value, varName)
if err != nil {
return "", err
}
result = strings.ReplaceAll(result, fullMatch, value)
}
return result, nil
}
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
// for compatibility with double-quoted YAML strings.
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
// Reject values that would break YAML structure regardless of quoting context
if strings.ContainsAny(value, "\n\r\x00") {
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
}
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
// In double-quoted contexts, they are interpreted correctly.
value = strings.ReplaceAll(value, `\`, `\\`)
value = strings.ReplaceAll(value, `"`, `\"`)
return value, nil
}
+300
View File
@@ -0,0 +1,300 @@
package config
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"gopkg.in/yaml.v3"
)
// identityMapPaths is the set of dotted paths whose direct children are
// identity-keyed maps. A child key present in two sources is a hard error;
// such keys name discrete entities (a model, a group, a peer, etc.) and a
// duplicate means the user has split one entity across files by mistake.
var identityMapPaths = map[string]bool{
"models": true,
"groups": true,
"profiles": true,
"peers": true,
"matrix": true,
"routing.router.settings.groups": true,
"routing.router.settings.matrix": true,
}
// LoadConfigSources loads and merges configuration from -config (optional)
// and -config-dir (optional). At least one must be provided. The -config file
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
// merged in sorted filename order. The merged document is passed through the
// existing LoadConfigFromReader pipeline unchanged.
func LoadConfigSources(configPath, configDir string) (Config, error) {
if configPath == "" && configDir == "" {
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
}
var sourcePaths []string
if configPath != "" {
sourcePaths = append(sourcePaths, configPath)
}
if configDir != "" {
dirFiles, err := listYAMLFiles(configDir)
if err != nil {
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
}
if configPath != "" {
absConfig, err := filepath.Abs(configPath)
if err != nil {
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
}
for _, f := range dirFiles {
absF, err := filepath.Abs(f)
if err != nil {
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
}
if absConfig == absF {
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
}
}
}
sourcePaths = append(sourcePaths, dirFiles...)
}
if len(sourcePaths) == 0 {
return Config{}, fmt.Errorf("no configuration sources found")
}
var merged *yaml.Node
for _, p := range sourcePaths {
node, err := parseSource(p)
if err != nil {
return Config{}, err
}
if node == nil {
continue // empty file
}
if merged == nil {
merged = node
continue
}
if err := mergeNodes(merged, node, "", p); err != nil {
return Config{}, err
}
}
if merged == nil {
// All sources were empty; run the pipeline on empty input so defaults
// and validation still apply (e.g. startPort, performance defaults).
return LoadConfigFromReader(strings.NewReader(""))
}
out, err := yaml.Marshal(merged)
if err != nil {
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
}
return LoadConfigFromReader(strings.NewReader(string(out)))
}
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
// filename for deterministic merge order. Subdirectories are not traversed.
func listYAMLFiles(dir string) ([]string, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
var files []string
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
continue
}
files = append(files, filepath.Join(dir, name))
}
sort.Strings(files)
return files, nil
}
// parseSource reads and parses one YAML config file into a root mapping node.
// Returns a nil node (no error) when the file is empty or contains only
// comments.
//
// Env macros (${env.VAR}) are substituted at the string level before YAML
// parsing so that flow-style constructs like [${env.API_KEY}] parse
// correctly — the brace would otherwise be interpreted as a flow mapping.
func parseSource(path string) (*yaml.Node, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
}
yamlStr, err := substituteEnvMacros(string(data))
if err != nil {
return nil, fmt.Errorf("config %s: %w", path, err)
}
var doc yaml.Node
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
}
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
// is the actual root. Unwrap it so callers see the real top-level node.
root := &doc
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
root = root.Content[0]
}
if root.Kind == 0 || root.Content == nil {
return nil, nil
}
if root.Kind != yaml.MappingNode {
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
}
return root, nil
}
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
// only one side are kept; shared keys are merged recursively under the rules
// in mergeValue. srcPath is included in error messages to identify the file
// that introduced the conflict.
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
srcIdx := indexMapping(src)
// First pass: merge shared keys in place.
for i := 0; i+1 < len(dst.Content); i += 2 {
keyNode := dst.Content[i]
dstVal := dst.Content[i+1]
key := keyNode.Value
srcVal, ok := srcIdx[key]
if !ok {
continue // dst-only key, keep as-is
}
childPath := joinPath(path, key)
if identityMapPaths[childPath] {
// Identity-keyed map: each child key names a discrete entity
// (a model, group, peer, ...). A shared child key is a hard
// error; src-only children are appended in the second pass.
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
return err
}
continue
}
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
return err
}
}
// Second pass: append src-only keys.
dstIdx := indexMapping(dst)
for i := 0; i+1 < len(src.Content); i += 2 {
keyNode := src.Content[i]
srcVal := src.Content[i+1]
key := keyNode.Value
if _, ok := dstIdx[key]; ok {
continue // already merged above
}
keyCopy := *keyNode
valCopy := *srcVal
dst.Content = append(dst.Content, &keyCopy, &valCopy)
}
return nil
}
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
// `groups`, `peers`). Any child key present in both sides is a duplicate
// entity and produces an error naming the conflicting key and source file.
// src-only keys are appended to dst.
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
}
dstIdx := indexMapping(dst)
for i := 0; i+1 < len(src.Content); i += 2 {
keyNode := src.Content[i]
srcVal := src.Content[i+1]
key := keyNode.Value
if _, dup := dstIdx[key]; dup {
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
}
keyCopy := *keyNode
valCopy := *srcVal
dst.Content = append(dst.Content, &keyCopy, &valCopy)
}
return nil
}
// mergeValue merges srcVal into dstVal (both pointing into the parent's
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
// Scalar+Scalar errors on value mismatch; null on either side yields to the
// non-null side.
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
switch {
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
return mergeNodes(dstVal, srcVal, path, srcPath)
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
dstVal.Content = append(dstVal.Content, srcVal.Content...)
return nil
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
if isNullScalar(dstVal) {
*dstVal = *srcVal
return nil
}
if isNullScalar(srcVal) {
return nil
}
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
return nil
}
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
case isNull(dstVal):
*dstVal = *srcVal
return nil
case isNull(srcVal):
return nil
default:
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
}
}
// isNull reports whether n represents a YAML null (empty or !!null).
func isNull(n *yaml.Node) bool {
if n == nil || n.Kind == 0 {
return true
}
return isNullScalar(n)
}
func isNullScalar(n *yaml.Node) bool {
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
}
// indexMapping builds a key -> value-node index for a mapping node.
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
idx := make(map[string]*yaml.Node, len(n.Content)/2)
for i := 0; i+1 < len(n.Content); i += 2 {
idx[n.Content[i].Value] = n.Content[i+1]
}
return idx
}
func joinPath(parent, key string) string {
if parent == "" {
return key
}
return parent + "." + key
}
+304
View File
@@ -0,0 +1,304 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// writeYAML writes content to a file named name inside dir. Returns the full
// path of the written file.
func writeYAML(t *testing.T, dir, name, content string) string {
t.Helper()
p := filepath.Join(dir, name)
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
return p
}
// modelCfg builds a single-model YAML snippet indented for nesting under a
// `models:` key. The proxy uses a fixed port so tests don't depend on
// ${PORT} allocation.
func modelCfg(id, cmd string) string {
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
}
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
_, err := LoadConfigSources("", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
}
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
dir := t.TempDir()
cfgPath := writeYAML(t, dir, "config.yaml", `
models:
`+modelCfg("model1", "echo hi")+`
groups:
group1:
members: ["model1"]
`)
cfg, err := LoadConfigSources(cfgPath, "")
require.NoError(t, err)
_, id, ok := cfg.FindConfig("model1")
require.True(t, ok)
assert.Equal(t, "model1", id)
}
func TestLoadConfigSources_DirOnly(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
for _, want := range []string{"alpha", "beta"} {
_, _, ok := cfg.FindConfig(want)
assert.True(t, ok, "model %s should be present", want)
}
}
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
// -config lives outside -config-dir; both contribute models additively.
dir := t.TempDir()
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
cfgDir := t.TempDir()
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
cfg, err := LoadConfigSources(cfgPath, cfgDir)
require.NoError(t, err)
for _, want := range []string{"base", "ext"} {
_, _, ok := cfg.FindConfig(want)
assert.True(t, ok, "model %s should be present after merge", want)
}
}
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
// is also a member of -config-dir is rejected.
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
dir := t.TempDir()
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
_, err := LoadConfigSources(cfgPath, dir)
require.Error(t, err)
assert.Contains(t, err.Error(), "is also present in -config-dir")
}
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
_, err := LoadConfigSources("", dir)
require.Error(t, err)
assert.Contains(t, err.Error(), `duplicate models "dup"`)
}
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", `
models:
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
writeYAML(t, dir, "b.yaml", `
models:
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
_, err := LoadConfigSources("", dir)
require.Error(t, err)
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
}
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
dir := t.TempDir()
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
_, err := LoadConfigSources("", dir)
require.Error(t, err)
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
}
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
_, err := LoadConfigSources("", dir)
require.Error(t, err)
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
}
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
assert.Equal(t, 100, cfg.GlobalTTL)
}
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
// Both macros are available globally after merge.
low, ok := cfg.Macros.Get("LOW")
require.True(t, ok)
assert.Equal(t, 1, low)
high, ok := cfg.Macros.Get("HIGH")
require.True(t, ok)
assert.Equal(t, 2, high)
}
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
}
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", `
models:
`+modelCfg("m1", "echo m1")+`
routing:
router:
settings:
groups:
groupA:
members: [m1]
`)
writeYAML(t, dir, "b.yaml", `
models:
`+modelCfg("m2", "echo m2")+`
routing:
router:
settings:
groups:
groupB:
members: [m2]
`)
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
groups := cfg.Routing.Router.Settings.Groups
assert.Contains(t, groups, "groupA")
assert.Contains(t, groups, "groupB")
// default group added by pipeline for orphaned/leftover routing groups...
// here both groups reference distinct models
}
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
dir := t.TempDir()
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
// verifies env/macro substitution runs on the merged document.
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
m := cfg.Models["m1"]
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
}
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
// Regression: flow-style lists with ${env.*} must parse. Previously
// parseSource unmarshalled before env substitution, so the brace in
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
t.Setenv("TEST_API_KEY", "secret123")
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
}
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
// Two files defining distinct models, scanned in z..a order by filename.
// Determine merged result is the same regardless of how the FS returns them.
dir := t.TempDir()
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
const runs = 3
for i := 0; i < runs; i++ {
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
// startPort-based allocation: first allocated model gets 5800.
// Sorted order means amodel gets 5800, zmodel gets 5801.
_, _, ok := cfg.FindConfig("amodel")
assert.True(t, ok)
_, _, ok = cfg.FindConfig("zmodel")
assert.True(t, ok)
}
}
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
dir := t.TempDir()
cfgDir := t.TempDir()
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
cfg, err := LoadConfigSources(cfgPath, cfgDir)
require.NoError(t, err)
assert.Contains(t, cfg.Models, "m1")
}
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
// An empty -config-dir with no -config is an error: there is nothing to
// load and silently producing an empty config would mask the misconfig.
cfgDir := t.TempDir()
_, err := LoadConfigSources("", cfgDir)
require.Error(t, err)
assert.Contains(t, err.Error(), "no configuration sources found")
}
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
// Macros defined in one file should not satisfy unknown-macro validation in
// another — they do, because merge concats global macros before validation
// runs. This test documents that a macro from file A is usable in file B.
dir := t.TempDir()
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
m := cfg.Models["user"]
assert.Contains(t, m.Cmd, "hello")
assert.NotContains(t, m.Cmd, "${SHARED}")
}
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
_, err := LoadConfigSources("", dir)
require.Error(t, err)
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
}
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
// File A: routing.router block absent (null on root for routing);
// file B: defines routing.router.settings.groups. Merge should keep B's.
dir := t.TempDir()
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
cfg, err := LoadConfigSources("", dir)
require.NoError(t, err)
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
}
+55
View File
@@ -0,0 +1,55 @@
package config
import (
"fmt"
"regexp"
"gopkg.in/yaml.v3"
)
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
// to upstream.ignorePaths when the section is empty or absent from the config.
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
// files do not trigger a model swap.
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
// when upstream.ignorePaths is not specified in the config. The returned slice
// is fresh so callers may mutate it without affecting other configs.
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
}
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
type UpstreamConfig struct {
// IgnorePaths is a slice of compiled regular expressions. Any request to
// /upstream/<model>/<path> whose remaining path matches any of these
// expressions will be ignored and not trigger a swap. When the config
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
IgnorePaths []*regexp.Regexp `yaml:"-"`
}
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
// plain strings, which are then compiled into *regexp.Regexp.
type rawUpstreamConfig struct {
IgnorePaths []string `yaml:"ignorePaths"`
}
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
// entry fails to compile, an error is returned.
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
var raw rawUpstreamConfig
if err := value.Decode(&raw); err != nil {
return err
}
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
for _, p := range raw.IgnorePaths {
re, err := regexp.Compile(p)
if err != nil {
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
}
patterns = append(patterns, re)
}
u.IgnorePaths = patterns
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package config
import (
"regexp"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const upstreamConfigHeader = `
models:
model1:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080"
`
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
// When upstream is not specified at all, the default pattern is applied.
content := upstreamConfigHeader
cfg, err := LoadConfigFromReader(strings.NewReader(content))
require.NoError(t, err)
require.Len(t, cfg.Upstream.IgnorePaths, 1)
def := cfg.Upstream.IgnorePaths[0]
assert.IsType(t, &regexp.Regexp{}, def)
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
// The default matches common static-asset suffixes.
assert.True(t, def.MatchString("/foo.js"))
assert.True(t, def.MatchString("/bar/baz.json"))
assert.True(t, def.MatchString("/static/img.png"))
assert.True(t, def.MatchString("/notes.txt"))
assert.True(t, def.MatchString("/favicon.ico"))
// And does not match inference API paths.
assert.False(t, def.MatchString("/v1/chat/completions"))
assert.False(t, def.MatchString("/v1/models"))
assert.False(t, def.MatchString("/health"))
}
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
// When upstream is present but ignorePaths is omitted, the default is still
// applied.
content := `upstream: {}` + "\n" + upstreamConfigHeader
cfg, err := LoadConfigFromReader(strings.NewReader(content))
require.NoError(t, err)
require.Len(t, cfg.Upstream.IgnorePaths, 1)
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
}
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
content := `
upstream:
ignorePaths:
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
- "^/static/.*"
` + upstreamConfigHeader
cfg, err := LoadConfigFromReader(strings.NewReader(content))
require.NoError(t, err)
require.Len(t, cfg.Upstream.IgnorePaths, 2)
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
// Confirm the type is *regexp.Regexp to satisfy the API contract.
for _, re := range cfg.Upstream.IgnorePaths {
assert.IsType(t, &regexp.Regexp{}, re)
}
}
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
content := `
upstream:
ignorePaths:
- "[invalid("
` + upstreamConfigHeader
_, err := LoadConfigFromReader(strings.NewReader(content))
require.Error(t, err)
assert.Contains(t, err.Error(), "upstream.ignorePaths")
assert.Contains(t, err.Error(), "invalid regular expression")
}
+92
View File
@@ -0,0 +1,92 @@
package perf
type LUID struct {
LowPart uint32
HighPart int32
}
const maxEnumAdapters = 16
type D3DKMT_ENUMADAPTERS2 struct {
NumAdapters uint32
pAdapters uintptr
}
type D3DKMT_ADAPTERINFO struct {
hAdapter uint32
AdapterLuid LUID
NumOfSources uint32
bPresentMoveRegionsPreferred int32
}
type D3DKMT_OPENADAPTERFROMLUID struct {
AdapterLuid LUID
hAdapter uint32
}
type D3DKMT_CLOSEADAPTER struct {
hAdapter uint32
}
type KMTQUERYADAPTERINFOTYPE int32
const (
KMTQAITYPE_UMDRIVERPRIVATE KMTQUERYADAPTERINFOTYPE = 0
KMTQAITYPE_ADAPTERREGISTRYINFO KMTQUERYADAPTERINFOTYPE = 8
KMTQAITYPE_DRIVERVERSION KMTQUERYADAPTERINFOTYPE = 13
KMTQAITYPE_PHYSICALADAPTERDEVICEIDS KMTQUERYADAPTERINFOTYPE = 31
KMTQAITYPE_NODEPERFDATA KMTQUERYADAPTERINFOTYPE = 61
KMTQAITYPE_ADAPTERPERFDATA KMTQUERYADAPTERINFOTYPE = 62
KMTQAITYPE_ADAPTERPERFDATA_CAPS KMTQUERYADAPTERINFOTYPE = 63
)
type D3DKMT_QUERYADAPTERINFO struct {
hAdapter uint32
Type KMTQUERYADAPTERINFOTYPE
pPrivateDriverData uintptr
PrivateDriverDataSize uint32
}
type D3DKMT_ADAPTER_PERFDATA struct {
PhysicalAdapterIndex uint32
MemoryFrequency uint64
MaxMemoryFrequency uint64
MaxMemoryFrequencyOC uint64
MemoryBandwidth uint64
PCIEBandwidth uint64
FanRPM uint32
Power uint32
Temperature uint32
PowerStateOverride byte
}
type D3DKMT_QUERYSTATISTICS_TYPE int32
const (
D3DKMT_QUERYSTATISTICS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 0
D3DKMT_QUERYSTATISTICS_PROCESS D3DKMT_QUERYSTATISTICS_TYPE = 1
D3DKMT_QUERYSTATISTICS_PROCESS_ADAPTER D3DKMT_QUERYSTATISTICS_TYPE = 2
D3DKMT_QUERYSTATISTICS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 3
D3DKMT_QUERYSTATISTICS_PROCESS_SEGMENT D3DKMT_QUERYSTATISTICS_TYPE = 4
D3DKMT_QUERYSTATISTICS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 5
D3DKMT_QUERYSTATISTICS_PROCESS_NODE D3DKMT_QUERYSTATISTICS_TYPE = 6
D3DKMT_QUERYSTATISTICS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 7
D3DKMT_QUERYSTATISTICS_PROCESS_VIDPNSOURCE D3DKMT_QUERYSTATISTICS_TYPE = 8
)
type D3DKMT_ADAPTER_PERFDATACAPS struct {
PhysicalAdapterIndex uint32
MaxMemoryBandwidth uint64
MaxPCIEBandwidth uint64
MaxFanRPM uint32
TemperatureMax uint32
TemperatureWarning uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_SEGMENT struct {
SegmentId uint32
}
type D3DKMT_QUERYSTATISTICS_QUERY_NODE struct {
NodeId uint32
}
+529
View File
@@ -0,0 +1,529 @@
//go:build windows
package perf
import (
"context"
"encoding/binary"
"fmt"
"sync"
"time"
"unsafe"
"github.com/mostlygeek/llama-swap/internal/logmon"
"golang.org/x/sys/windows"
)
var (
d3dkmDLL *windows.LazyDLL
procEnumAdapters2 *windows.LazyProc
procOpenAdapterFromLuid *windows.LazyProc
procCloseAdapter *windows.LazyProc
procQueryAdapterInfo *windows.LazyProc
procQueryStatistics *windows.LazyProc
d3dkmtInitOnce sync.Once
d3dkmtInitErr error
)
// initD3DKMT lazily loads gdi32.dll and resolves D3DKMT function pointers.
// Safe for concurrent use via sync.Once.
func initD3DKMT() error {
d3dkmtInitOnce.Do(func() {
d3dkmDLL = windows.NewLazySystemDLL("gdi32.dll")
procEnumAdapters2 = d3dkmDLL.NewProc("D3DKMTEnumAdapters2")
procOpenAdapterFromLuid = d3dkmDLL.NewProc("D3DKMTOpenAdapterFromLuid")
procCloseAdapter = d3dkmDLL.NewProc("D3DKMTCloseAdapter")
procQueryAdapterInfo = d3dkmDLL.NewProc("D3DKMTQueryAdapterInfo")
procQueryStatistics = d3dkmDLL.NewProc("D3DKMTQueryStatistics")
for name, p := range map[string]*windows.LazyProc{
"D3DKMTEnumAdapters2": procEnumAdapters2,
"D3DKMTOpenAdapterFromLuid": procOpenAdapterFromLuid,
"D3DKMTCloseAdapter": procCloseAdapter,
"D3DKMTQueryAdapterInfo": procQueryAdapterInfo,
"D3DKMTQueryStatistics": procQueryStatistics,
} {
if err := p.Find(); err != nil {
d3dkmtInitErr = fmt.Errorf("D3DKMT %s not found: %w", name, err)
return
}
}
})
return d3dkmtInitErr
}
// ntstatusCall invokes a D3DKMT function and returns a non-nil error if the
// NTSTATUS result is not STATUS_SUCCESS (0).
func ntstatusCall(proc *windows.LazyProc, arg unsafe.Pointer) error {
ret, _, _ := proc.Call(uintptr(arg))
if ret != 0 {
return fmt.Errorf("NTSTATUS 0x%08x", uint32(ret))
}
return nil
}
// d3dkmEnumerateAdapters enumerates all available graphics adapters via
// D3DKMTEnumAdapters2.
func d3dkmEnumerateAdapters() ([]D3DKMT_ADAPTERINFO, error) {
var adapters [maxEnumAdapters]D3DKMT_ADAPTERINFO
enum := D3DKMT_ENUMADAPTERS2{
NumAdapters: maxEnumAdapters,
pAdapters: uintptr(unsafe.Pointer(&adapters[0])),
}
if err := ntstatusCall(procEnumAdapters2, unsafe.Pointer(&enum)); err != nil {
return nil, fmt.Errorf("EnumAdapters2: %w", err)
}
if enum.NumAdapters == 0 {
return nil, fmt.Errorf("no adapters found")
}
result := make([]D3DKMT_ADAPTERINFO, enum.NumAdapters)
for i := uint32(0); i < enum.NumAdapters; i++ {
result[i] = adapters[i]
}
return result, nil
}
// d3dkmOpenAdapter opens a D3DKMT adapter handle for the given LUID.
func d3dkmOpenAdapter(luid LUID) (uint32, error) {
req := D3DKMT_OPENADAPTERFROMLUID{
AdapterLuid: luid,
}
if err := ntstatusCall(procOpenAdapterFromLuid, unsafe.Pointer(&req)); err != nil {
return 0, fmt.Errorf("OpenAdapterFromLuid: %w", err)
}
return req.hAdapter, nil
}
// d3dkmCloseAdapter closes a previously opened D3DKMT adapter handle.
func d3dkmCloseAdapter(hAdapter uint32) error {
req := D3DKMT_CLOSEADAPTER{hAdapter: hAdapter}
return ntstatusCall(procCloseAdapter, unsafe.Pointer(&req))
}
// d3dkmGetAdapterPerfData queries per-adapter performance data (temperature,
// fan RPM, power, bandwidth) via KMTQAITYPE_ADAPTERPERFDATA.
func d3dkmGetAdapterPerfData(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATA, error) {
var data D3DKMT_ADAPTER_PERFDATA
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATA): %w", err)
}
return &data, nil
}
// d3dkmGetAdapterPerfDataCaps queries static adapter performance capabilities
// (max fan RPM, temperature limits, max bandwidth) via KMTQAITYPE_ADAPTERPERFDATA_CAPS.
func d3dkmGetAdapterPerfDataCaps(hAdapter uint32) (*D3DKMT_ADAPTER_PERFDATACAPS, error) {
var data D3DKMT_ADAPTER_PERFDATACAPS
req := D3DKMT_QUERYADAPTERINFO{
hAdapter: hAdapter,
Type: KMTQAITYPE_ADAPTERPERFDATA_CAPS,
pPrivateDriverData: uintptr(unsafe.Pointer(&data)),
PrivateDriverDataSize: uint32(unsafe.Sizeof(data)),
}
if err := ntstatusCall(procQueryAdapterInfo, unsafe.Pointer(&req)); err != nil {
return nil, fmt.Errorf("QueryAdapterInfo(ADAPTERPERFDATACAPS): %w", err)
}
return &data, nil
}
type queryStatsBuffer struct {
Type int32 // offset 0
AdapterLuid LUID // offset 4
hProcess uintptr // offset 16
// _result mirrors the D3DKMT_QUERYSTATISTICS_RESULT union.
// sizeof(D3DKMT_QUERYSTATISTICS) == 0x328 (808 bytes) on x64.
//
// The C struct layout (x64):
// offset 0: Type (int32, 4 bytes)
// offset 4: AdapterLuid (LUID, 8 bytes)
// offset 12: 4 bytes padding (for 8-byte alignment of hProcess)
// offset 16: hProcess (HANDLE, 8 bytes)
// offset 24: QueryResult (union, 780 bytes — largest member is AdapterInformation)
// offset 804: anonymous input union (QueryNode.NodeId / QuerySegment.SegmentId, 4 bytes)
//
// Previous bug: _result was [776]byte, placing QueryId at offset 800 instead of 804.
// The kernel read NodeId/SegmentId from offset 804 (always zero from _pad),
// causing all NODE and SEGMENT queries to use index 0 regardless of the value
// passed in QueryId. This produced alternating behavior where only GPU util OR
// memory util appeared to work, depending on which test variant happened to put
// non-zero data near offset 804 in the result buffer.
_result [780]byte // offset 24, size 780 — places QueryId at offset 804
QueryId int32 // offset 804 — matches C anonymous union for NodeId/SegmentId
}
func init() {
var buf queryStatsBuffer
if unsafe.Sizeof(buf) != 808 {
panic(fmt.Sprintf("queryStatsBuffer size %d != expected 808 (sizeof D3DKMT_QUERYSTATISTICS on x64)", unsafe.Sizeof(buf)))
}
if unsafe.Offsetof(buf.QueryId) != 804 {
panic(fmt.Sprintf("queryStatsBuffer.QueryId offset %d != expected 804 (C anonymous union offset)", unsafe.Offsetof(buf.QueryId)))
}
var perfData D3DKMT_ADAPTER_PERFDATA
if unsafe.Sizeof(perfData) != 64 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATA size %d != expected 64 on x64", unsafe.Sizeof(perfData)))
}
var caps D3DKMT_ADAPTER_PERFDATACAPS
if unsafe.Sizeof(caps) != 40 {
panic(fmt.Sprintf("D3DKMT_ADAPTER_PERFDATACAPS size %d != expected 40 on x64", unsafe.Sizeof(caps)))
}
}
const (
qsoffsetNbSegments = 0
qsoffsetNodeCount = 4
qsoffsetCommitLimit = 0
qsoffsetBytesCommitted = 8
qsoffsetBytesResident = 16
qsoffsetRunningTime = 0
qsoffsetSystemRunningTime = 272
)
// d3dkmQueryAdapterStats returns the number of memory segments and compute
// nodes for the adapter identified by luid.
func d3dkmQueryAdapterStats(luid LUID) (nbSegments uint32, nodeCount uint32, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_ADAPTER),
AdapterLuid: luid,
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(ADAPTER): %w", err)
}
nbSegments = binary.LittleEndian.Uint32(buf._result[qsoffsetNbSegments : qsoffsetNbSegments+4])
nodeCount = binary.LittleEndian.Uint32(buf._result[qsoffsetNodeCount : qsoffsetNodeCount+4])
return nbSegments, nodeCount, nil
}
// d3dkmQuerySegmentStats returns the commit limit (total) and resident
// (used) bytes for the given memory segment of an adapter.
func d3dkmQuerySegmentStats(luid LUID, segmentID uint32) (commitLimit uint64, bytesResident uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_SEGMENT),
AdapterLuid: luid,
QueryId: int32(segmentID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(SEGMENT %d): %w", segmentID, err)
}
commitLimit = binary.LittleEndian.Uint64(buf._result[qsoffsetCommitLimit : qsoffsetCommitLimit+8])
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesResident : qsoffsetBytesResident+8])
if bytesResident == 0 {
bytesResident = binary.LittleEndian.Uint64(buf._result[qsoffsetBytesCommitted : qsoffsetBytesCommitted+8])
}
return commitLimit, bytesResident, nil
}
// d3dkmQueryNodeStats returns the global and system running time counters
// (in 100ns units) for the given compute node of an adapter.
func d3dkmQueryNodeStats(luid LUID, nodeID uint32) (runningTime uint64, systemRunningTime uint64, err error) {
buf := queryStatsBuffer{
Type: int32(D3DKMT_QUERYSTATISTICS_NODE),
AdapterLuid: luid,
QueryId: int32(nodeID),
}
if err := ntstatusCall(procQueryStatistics, unsafe.Pointer(&buf)); err != nil {
return 0, 0, fmt.Errorf("QueryStatistics(NODE %d): %w", nodeID, err)
}
runningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetRunningTime : qsoffsetRunningTime+8])
systemRunningTime = binary.LittleEndian.Uint64(buf._result[qsoffsetSystemRunningTime : qsoffsetSystemRunningTime+8])
return runningTime, systemRunningTime, nil
}
type nodeRunningTimes struct {
Global uint64
System uint64
}
// d3dkmtNodeUtil computes GPU node utilization as a percentage from running
// time deltas. Returns -1 if counters went backwards (wrap/reset), 0 if idle.
func d3dkmtNodeUtil(prevRT, curRT nodeRunningTimes, elapsed100ns int64) float64 {
if curRT.Global < prevRT.Global || curRT.System < prevRT.System {
return -1
}
gd := curRT.Global - prevRT.Global
sd := curRT.System - prevRT.System
if gd > 0 && sd > 0 {
util := float64(gd) / float64(sd)
if util > 1.0 {
util = 1.0
}
return util * 100.0
} else if gd > 0 && elapsed100ns > 0 {
util := float64(gd) / float64(elapsed100ns) * 100.0
if util > 100.0 {
util = 100.0
}
return util
}
return 0
}
// d3dkmtFanPct returns fan speed as a percentage of maxFanRPM, clamped to
// 100%. Returns 0 if maxFanRPM is unavailable or fan is not spinning.
func d3dkmtFanPct(fanRPM, maxFanRPM uint32) float64 {
if maxFanRPM > 0 && fanRPM > 0 {
pct := float64(fanRPM) / float64(maxFanRPM) * 100.0
if pct > 100.0 {
pct = 100.0
}
return pct
}
return 0
}
// d3dkmtPowerW converts power from deci-watts (as reported by D3DKMT) to
// watts. Returns 0 if the power value is zero.
func d3dkmtPowerW(power uint32) float64 {
if power > 0 {
return float64(power) / 10.0
}
return 0
}
// d3dkmtTempC converts temperature from deci-Celsius (as reported by D3DKMT)
// to degrees Celsius.
func d3dkmtTempC(tempDeciC uint32) int {
return int(tempDeciC / 10)
}
type d3dkmtAdapterState struct {
luid LUID
hAdapter uint32
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
prevNodeRT map[uint32]nodeRunningTimes
prevTime time.Time
}
// tryD3DKMT attempts to start GPU monitoring using D3DKMT and optional PDH
// counters. It returns a channel of GpuStat snapshots or an error if no
// usable adapters are found.
func tryD3DKMT(ctx context.Context, every time.Duration, logger *logmon.Monitor) (chan []GpuStat, error) {
if err := initD3DKMT(); err != nil {
return nil, err
}
adapterInfos, err := d3dkmEnumerateAdapters()
if err != nil {
return nil, err
}
type adapterMeta struct {
luid LUID
nbSegments uint32
nodeCount uint32
maxFanRPM uint32
}
var metaList []adapterMeta
for i, ai := range adapterInfos {
hAdapter, err := d3dkmOpenAdapter(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: open failed: %s", i, err.Error())
continue
}
nbSegments, nodeCount, err := d3dkmQueryAdapterStats(ai.AdapterLuid)
if err != nil {
logger.Debugf("adapter %d: query stats failed: %s", i, err.Error())
d3dkmCloseAdapter(hAdapter)
continue
}
caps, err := d3dkmGetAdapterPerfDataCaps(hAdapter)
if err != nil {
logger.Debugf("adapter %d: perf caps failed: %s", i, err.Error())
}
d3dkmCloseAdapter(hAdapter)
var maxFanRPM uint32
if caps != nil {
maxFanRPM = caps.MaxFanRPM
}
metaList = append(metaList, adapterMeta{
luid: ai.AdapterLuid,
nbSegments: nbSegments,
nodeCount: nodeCount,
maxFanRPM: maxFanRPM,
})
logger.Debugf("adapter %d: segments=%d nodes=%d fan_max=%d luid=%d:%d", i, nbSegments, nodeCount, maxFanRPM, ai.AdapterLuid.HighPart, ai.AdapterLuid.LowPart)
}
if len(metaList) == 0 {
return nil, fmt.Errorf("no usable D3DKMT adapters found")
}
pdhUtil, pdhErr := initPdhGpuUtil()
if pdhErr != nil {
logger.Debugf("PDH GPU utilization not available: %s", pdhErr.Error())
} else {
logger.Info("using PDH performance counters for GPU utilization")
}
ch := make(chan []GpuStat, 1)
go func() {
defer close(ch)
if pdhUtil != nil {
defer pdhUtil.close()
}
var adapters []d3dkmtAdapterState
for _, m := range metaList {
hAdapter, err := d3dkmOpenAdapter(m.luid)
if err != nil {
logger.Debugf("reopen adapter failed: %s", err.Error())
continue
}
adapters = append(adapters, d3dkmtAdapterState{
luid: m.luid,
hAdapter: hAdapter,
nbSegments: m.nbSegments,
nodeCount: m.nodeCount,
maxFanRPM: m.maxFanRPM,
prevNodeRT: make(map[uint32]nodeRunningTimes),
})
}
if len(adapters) == 0 {
return
}
defer func() {
for _, a := range adapters {
d3dkmCloseAdapter(a.hAdapter)
}
}()
for i := range adapters {
a := &adapters[i]
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = time.Now()
}
ticker := time.NewTicker(every)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
stats := make([]GpuStat, 0, len(adapters))
now := time.Now()
var pdhUtilMap map[LUID]float64
if pdhUtil != nil {
pdhUtilMap = pdhUtil.collect()
}
for i := range adapters {
a := &adapters[i]
perfData, err := d3dkmGetAdapterPerfData(a.hAdapter)
if err != nil {
logger.Debugf("adapter %d perfdata: %s", i, err.Error())
continue
}
var memUsedMB, memTotalMB int
for seg := uint32(0); seg < a.nbSegments; seg++ {
limit, resident, err := d3dkmQuerySegmentStats(a.luid, seg)
if err != nil {
continue
}
memUsedMB += int(resident / (1024 * 1024))
memTotalMB += int(limit / (1024 * 1024))
}
var gpuUtil float64
pdhGaveValue := false
if pdhUtilMap != nil {
if util, ok := pdhUtilMap[a.luid]; ok {
gpuUtil = util
pdhGaveValue = true
}
}
if !pdhGaveValue && a.nodeCount > 0 {
elapsedNs := now.Sub(a.prevTime).Nanoseconds()
elapsed100ns := elapsedNs / 100
for node := uint32(0); node < a.nodeCount; node++ {
globalRT, systemRT, err := d3dkmQueryNodeStats(a.luid, node)
if err != nil {
continue
}
if prevRT, ok := a.prevNodeRT[node]; ok {
if globalRT < prevRT.Global || systemRT < prevRT.System {
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
continue
}
nodeUtil := d3dkmtNodeUtil(prevRT, nodeRunningTimes{Global: globalRT, System: systemRT}, elapsed100ns)
if nodeUtil > gpuUtil {
gpuUtil = nodeUtil
}
}
a.prevNodeRT[node] = nodeRunningTimes{Global: globalRT, System: systemRT}
}
a.prevTime = now
}
tempC := d3dkmtTempC(perfData.Temperature)
fanSpeedPct := d3dkmtFanPct(perfData.FanRPM, a.maxFanRPM)
powerDrawW := d3dkmtPowerW(perfData.Power)
var memUtilPct float64
if memTotalMB > 0 {
memUtilPct = float64(memUsedMB) / float64(memTotalMB) * 100.0
}
stats = append(stats, GpuStat{
Timestamp: now,
ID: i,
Name: fmt.Sprintf("GPU %d", i),
TempC: tempC,
GpuUtilPct: gpuUtil,
MemUtilPct: memUtilPct,
MemUsedMB: memUsedMB,
MemTotalMB: memTotalMB,
FanSpeedPct: fanSpeedPct,
PowerDrawW: powerDrawW,
})
}
if len(stats) > 0 {
select {
case ch <- stats:
default:
}
}
}
}
}()
return ch, nil
}
+98
View File
@@ -0,0 +1,98 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestD3dkmtNodeUtil_FullLoad(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 5000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_PartialUtil(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 3000, System: 14000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 50.0, got)
}
func TestD3dkmtNodeUtil_Identical(t *testing.T) {
prev := nodeRunningTimes{Global: 10000, System: 10000}
cur := nodeRunningTimes{Global: 20000, System: 20000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 100.0, got)
}
func TestD3dkmtNodeUtil_CounterWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 9000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_SystemWrap(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 9000}
cur := nodeRunningTimes{Global: 5000, System: 1000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, -1.0, got)
}
func TestD3dkmtNodeUtil_ZeroDelta(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 1000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 100000)
assert.Equal(t, 0.0, got)
}
func TestD3dkmtNodeUtil_ElapsedFallback(t *testing.T) {
prev := nodeRunningTimes{Global: 1000, System: 10000}
cur := nodeRunningTimes{Global: 6000, System: 10000}
got := d3dkmtNodeUtil(prev, cur, 50000)
assert.InDelta(t, 10.0, got, 0.01)
}
func TestD3dkmtFanPct_Normal(t *testing.T) {
assert.Equal(t, 50.0, d3dkmtFanPct(1500, 3000))
}
func TestD3dkmtFanPct_MaxFan(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(3000, 3000))
}
func TestD3dkmtFanPct_OverMaxClamped(t *testing.T) {
assert.Equal(t, 100.0, d3dkmtFanPct(4000, 3000))
}
func TestD3dkmtFanPct_ZeroMaxFan(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(1500, 0))
}
func TestD3dkmtFanPct_ZeroFanRPM(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 3000))
}
func TestD3dkmtFanPct_BothZero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtFanPct(0, 0))
}
func TestD3dkmtPowerW(t *testing.T) {
assert.Equal(t, 250.0, d3dkmtPowerW(2500))
}
func TestD3dkmtPowerW_Zero(t *testing.T) {
assert.Equal(t, 0.0, d3dkmtPowerW(0))
}
func TestD3dkmtTempC(t *testing.T) {
assert.Equal(t, 65, d3dkmtTempC(650))
}
func TestD3dkmtTempC_Zero(t *testing.T) {
assert.Equal(t, 0, d3dkmtTempC(0))
}
+7
View File
@@ -22,6 +22,13 @@ func getGpuStats(ctx context.Context, every time.Duration, logger *logmon.Monito
logger.Debugf("nvidia-smi: %s", err.Error()) logger.Debugf("nvidia-smi: %s", err.Error())
} }
if ch, err := tryD3DKMT(ctx, every, logger); err == nil {
logger.Info("using D3DKMT for GPU monitoring")
return ch, nil
} else {
logger.Debugf("D3DKMT: %s", err.Error())
}
return nil, ErrNoGpuTool return nil, ErrNoGpuTool
} }
+159
View File
@@ -0,0 +1,159 @@
//go:build windows
package perf
import (
"fmt"
"strconv"
"strings"
"unsafe"
"golang.org/x/sys/windows"
)
var (
pdhDLL = windows.NewLazySystemDLL("pdh.dll")
procPdhOpenQuery = pdhDLL.NewProc("PdhOpenQueryW")
procPdhAddEnglishCounter = pdhDLL.NewProc("PdhAddEnglishCounterW")
procPdhCollectQueryData = pdhDLL.NewProc("PdhCollectQueryData")
procPdhGetFormattedCounterArray = pdhDLL.NewProc("PdhGetFormattedCounterArrayW")
procPdhCloseQuery = pdhDLL.NewProc("PdhCloseQuery")
)
const (
pdhFmtDouble = 0x00000200
pdhMoreData = 0x800007D2
pdhNoData = 0x800007D5
)
type pdhCounterValue struct {
CStatus uint32
DblVal float64
}
type pdhCounterValueItem struct {
SzName *uint16
FmtValue pdhCounterValue
}
func init() {
var item pdhCounterValueItem
if unsafe.Sizeof(item) != 24 {
panic(fmt.Sprintf("pdhCounterValueItem size %d != expected 24 on x64", unsafe.Sizeof(item)))
}
}
type pdhGpuUtil struct {
query uintptr
counter uintptr
}
// initPdhGpuUtil creates a PDH query for the GPU Engine utilization counter.
// Returns nil with an error if PDH or the counter is unavailable.
func initPdhGpuUtil() (*pdhGpuUtil, error) {
var query uintptr
if ret, _, _ := procPdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query))); ret != 0 {
return nil, fmt.Errorf("PdhOpenQuery: 0x%x", ret)
}
path, _ := windows.UTF16PtrFromString(`\GPU Engine(*)\Utilization Percentage`)
var counter uintptr
if ret, _, _ := procPdhAddEnglishCounter.Call(
query, uintptr(unsafe.Pointer(path)), 0, uintptr(unsafe.Pointer(&counter)),
); ret != 0 {
procPdhCloseQuery.Call(query)
return nil, fmt.Errorf("PdhAddEnglishCounter(GPU Engine): 0x%x", ret)
}
procPdhCollectQueryData.Call(query)
return &pdhGpuUtil{query: query, counter: counter}, nil
}
// close releases the PDH query handle.
func (p *pdhGpuUtil) close() {
if p.query != 0 {
procPdhCloseQuery.Call(p.query)
p.query = 0
}
}
// collect reads the PDH counter and returns a map of adapter LUID to
// aggregated GPU utilization percentage, summed across all engine instances
// per adapter and clamped to 100%.
func (p *pdhGpuUtil) collect() map[LUID]float64 {
ret, _, _ := procPdhCollectQueryData.Call(p.query)
if ret != 0 && ret != pdhNoData {
return nil
}
var bufSize uint32
var itemCount uint32
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
0,
)
if ret != pdhMoreData || itemCount == 0 {
return nil
}
buf := make([]byte, bufSize)
ret, _, _ = procPdhGetFormattedCounterArray.Call(
p.counter, pdhFmtDouble,
uintptr(unsafe.Pointer(&bufSize)),
uintptr(unsafe.Pointer(&itemCount)),
uintptr(unsafe.Pointer(&buf[0])),
)
if ret != 0 {
return nil
}
itemSize := uint32(unsafe.Sizeof(pdhCounterValueItem{}))
result := make(map[LUID]float64)
for i := uint32(0); i < itemCount; i++ {
item := (*pdhCounterValueItem)(unsafe.Pointer(&buf[i*itemSize]))
if item.FmtValue.CStatus != 0 {
continue
}
luid, ok := parsePdhLuid(windows.UTF16PtrToString(item.SzName))
if !ok {
continue
}
result[luid] += item.FmtValue.DblVal
}
for luid := range result {
if result[luid] > 100.0 {
result[luid] = 100.0
}
}
return result
}
// parsePdhLuid extracts the adapter LUID (high and low parts) from a PDH
// GPU Engine instance name (e.g. "pid_1234_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute").
func parsePdhLuid(name string) (LUID, bool) {
idx := strings.Index(name, "luid_0x")
if idx < 0 {
return LUID{}, false
}
rest := name[idx+7:]
parts := strings.SplitN(rest, "_", 4)
if len(parts) < 3 {
return LUID{}, false
}
hp, err := strconv.ParseUint(parts[0], 16, 32)
if err != nil {
return LUID{}, false
}
lpStr := strings.TrimPrefix(parts[1], "0x")
lp, err := strconv.ParseUint(lpStr, 16, 32)
if err != nil {
return LUID{}, false
}
return LUID{LowPart: uint32(lp), HighPart: int32(hp)}, true
}
+53
View File
@@ -0,0 +1,53 @@
//go:build windows
package perf
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParsePdhLuid_Valid(t *testing.T) {
name := `pid_25312_luid_0x00000000_0x000148BF_phys_0_eng_2_engtype_Compute`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x000148BF), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_ValidNvidia(t *testing.T) {
name := `pid_1388_luid_0x00000000_0x00011372_phys_0_eng_8_engtype_Compute_1`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x00011372), got.LowPart)
assert.Equal(t, int32(0x00000000), got.HighPart)
}
func TestParsePdhLuid_NonZeroHighPart(t *testing.T) {
name := `pid_1234_luid_0x00000001_0x0000C85A_phys_0_eng_5_engtype_Copy`
got, ok := parsePdhLuid(name)
assert.True(t, ok)
assert.Equal(t, uint32(0x0000C85A), got.LowPart)
assert.Equal(t, int32(0x00000001), got.HighPart)
}
func TestParsePdhLuid_InvalidNoLuid(t *testing.T) {
_, ok := parsePdhLuid("invalid_string_without_luid")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidEmpty(t *testing.T) {
_, ok := parsePdhLuid("")
assert.False(t, ok)
}
func TestParsePdhLuid_InvalidHex(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0xZZZZ_0xGGGG_phys_0")
assert.False(t, ok)
}
func TestParsePdhLuid_ShortAfterLuid(t *testing.T) {
_, ok := parsePdhLuid("pid_1234_luid_0x00000000")
assert.False(t, ok)
}
+6
View File
@@ -3,6 +3,7 @@ package scheduler
import ( import (
"fmt" "fmt"
"sort" "sort"
"strconv"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
@@ -278,6 +279,11 @@ func (s *FIFO) grantHandler(req HandlerReq, modelID string) {
s.effects.GrantError(req, shared.ConcurrencyLimitError{}) s.effects.GrantError(req, shared.ConcurrencyLimitError{})
return return
} }
if err := shared.SetReqData(req.Ctx, "fifo_priority", strconv.Itoa(s.cfg.Priority[req.Model])); err != nil {
s.logger.Debugf("failed to set fifo_priority metadata: %v", err)
}
if s.effects.GrantServe(req, modelID) { if s.effects.GrantServe(req, modelID) {
s.inFlight[modelID]++ s.inFlight[modelID]++
} }
+26 -2
View File
@@ -1,6 +1,7 @@
package scheduler package scheduler
import ( import (
"context"
"errors" "errors"
"io" "io"
"net/http" "net/http"
@@ -54,8 +55,9 @@ type stopRec struct {
// fakeEffects is an in-memory scheduler.Effects. Tests program process states // fakeEffects is an in-memory scheduler.Effects. Tests program process states
// and GrantServe outcomes, then assert on the recorded calls. // and GrantServe outcomes, then assert on the recorded calls.
type fakeEffects struct { type fakeEffects struct {
states map[string]process.ProcessState // model -> state; missing => not handled states map[string]process.ProcessState // model -> state; missing => not handled
serveResult map[string]bool // GrantServe return per model (default true) serveResult map[string]bool // GrantServe return per model (default true)
lastServeReq HandlerReq
starts []startRec starts []startRec
grants []grantRec grants []grantRec
@@ -98,6 +100,7 @@ func (f *fakeEffects) GrantServe(req HandlerReq, modelID string) bool {
if v, set := f.serveResult[modelID]; set { if v, set := f.serveResult[modelID]; set {
ok = v ok = v
} }
f.lastServeReq = req
f.grants = append(f.grants, grantRec{model: modelID, serve: ok}) f.grants = append(f.grants, grantRec{model: modelID, serve: ok})
return ok return ok
} }
@@ -169,6 +172,27 @@ func TestFIFO_FastPath(t *testing.T) {
} }
} }
func TestFIFO_GrantSetsPriorityMetadata(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
cfg := config.FifoConfig{Priority: map[string]int{"a": 7}}
s := NewFIFO("test", logmon.NewWriter(io.Discard), &stubPlanner{}, cfg, nil, eff)
ctx := shared.SetContext(context.Background(), shared.ReqContextData{ModelID: "a", Metadata: make(map[string]string)})
s.OnRequest(HandlerReq{Model: "a", Ctx: ctx})
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1", got)
}
data, ok := shared.ReadContext(eff.lastServeReq.Ctx)
if !ok {
t.Fatal("context data missing from granted request")
}
if data.Metadata["fifo_priority"] != "7" {
t.Errorf("fifo_priority = %q, want 7", data.Metadata["fifo_priority"])
}
}
func TestFIFO_ModelNotFound(t *testing.T) { func TestFIFO_ModelNotFound(t *testing.T) {
eff := newFakeEffects() // no states => model unknown eff := newFakeEffects() // no states => model unknown
s := newFIFO(&stubPlanner{}, eff) s := newFIFO(&stubPlanner{}, eff)
+11 -3
View File
@@ -92,9 +92,14 @@ type Effects interface {
StopProcesses(timeout time.Duration, ids []string) StopProcesses(timeout time.Duration, ids []string)
} }
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured // New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
// from conf and bound to the given planner and effects. Currently only "fifo" // conf and bound to the given planner and effects. Supported values are "fifo"
// (the default) is supported. // (throughput-oriented, batches same-model requests) and "serial" (strict
// one-model-at-a-time, exact arrival order).
//
// The deployment default is applied by config loading (LoadConfig sets Use to
// "serial" when unset). The "" fallback here is the library default and remains
// "fifo" so callers that build a Config directly keep the original behavior.
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) { func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
use := conf.Routing.Scheduler.Use use := conf.Routing.Scheduler.Use
if use == "" { if use == "" {
@@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe
switch use { switch use {
case "fifo": case "fifo":
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
case "serial":
// Serial ignores the group planner: it always evicts every other model.
return NewSerial(name, logger, eff), nil
default: default:
return nil, fmt.Errorf("unsupported scheduler type: %q", use) return nil, fmt.Errorf("unsupported scheduler type: %q", use)
} }
+253
View File
@@ -0,0 +1,253 @@
package scheduler
import (
"fmt"
"sort"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
// or batches: requests run in exact arrival order and at most one request runs at
// any instant. When the next request targets a model other than the one loaded,
// every other running model is evicted and the target is loaded before it runs,
// so a single model occupies memory at a time — at the cost of throughput.
//
// Example: A B C A is served as A B C A. The final A reloads its model even
// though it ran first, because B and C displaced it in between. (FIFO, by
// contrast, would batch the two A requests: A A B C.)
//
// Serial ignores group/eviction policy entirely: it always evicts every other
// running model, regardless of how groups are configured. That is what makes the
// single-model guarantee a property of the scheduler rather than of the config.
//
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
// internal locking is needed.
type Serial struct {
name string
logger *logmon.Monitor
effects Effects
// queued holds requests in strict arrival order. It is never reordered.
queued []HandlerReq
// active is the one request currently being processed (loading or serving),
// or nil when idle. phase is meaningful only while active != nil.
active *HandlerReq
phase serialPhase
}
// serialPhase is the lifecycle stage of the active request.
type serialPhase int
const (
phaseIdle serialPhase = iota
phaseSwapping // waiting for OnSwapDone for active.Model
phaseServing // waiting for OnServeDone for active.Model
)
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
// "stop every other running model", so the group planner is not consulted.
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
return &Serial{
name: name,
logger: logger,
effects: eff,
}
}
// OnRequest validates the model and appends the request to the tail of the queue,
// then tries to start the next job. Unknown models fail immediately.
func (s *Serial) OnRequest(req HandlerReq) {
if _, ok := s.effects.ModelState(req.Model); !ok {
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
s.effects.GrantError(req, ErrModelNotFound)
return
}
s.queued = append(s.queued, req)
broadcastQueuePositions(s.queued)
s.startNext()
}
// startNext begins processing the head of the queue when nothing is active. It
// fast-paths a request whose model is already the sole loaded-and-ready process;
// otherwise it launches a swap that evicts every other running model first. The
// loop skips over requests for models that vanished (e.g. a config reload) and
// requests whose caller disconnected before they could be served.
func (s *Serial) startNext() {
if s.active != nil {
return // a job is already loading or serving
}
for len(s.queued) > 0 {
req := s.queued[0]
s.queued = s.queued[1:]
broadcastQueuePositions(s.queued)
state, ok := s.effects.ModelState(req.Model)
if !ok {
s.effects.GrantError(req, ErrModelNotFound)
continue
}
r := req
s.active = &r
evict := s.otherRunning(req.Model)
if state == process.StateReady && len(evict) == 0 {
// Already loaded and the only model running — serve immediately.
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
if s.serve() {
return
}
continue // caller gone; pick the next request
}
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
s.phase = phaseSwapping
s.effects.StartSwap(req.Model, evict)
return
}
}
// serve hands the active request its tracked handler. It returns true when the
// request is now serving (await OnServeDone); false when the caller had already
// disconnected, in which case active is cleared so the next job can start.
func (s *Serial) serve() bool {
if s.effects.GrantServe(*s.active, s.active.Model) {
s.phase = phaseServing
return true
}
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
s.active = nil
s.phase = phaseIdle
return false
}
// OnSwapDone fires when the load for the active request completes. On success the
// request is served; on failure its caller receives the error and the queue
// advances. A SwapDone that does not match the active load (e.g. its request was
// unloaded or cancelled mid-load) is ignored.
func (s *Serial) OnSwapDone(ev SwapDone) {
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
return
}
if ev.Err != nil {
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
s.effects.GrantError(*s.active, ev.Err)
s.active = nil
s.phase = phaseIdle
s.startNext()
return
}
if !s.serve() {
s.startNext() // caller vanished while the model loaded; move on
}
}
// OnServeDone fires when the active request's handler returns. The slot is freed
// and the next queued request begins.
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
if s.active == nil || s.phase != phaseServing {
return
}
s.active = nil
s.phase = phaseIdle
s.startNext()
}
// OnCancel removes a disconnected client's request from the queue. A request that
// is already active is left to finish: if it was loading, OnSwapDone's serve()
// will find the caller gone (GrantServe false) and advance; if it was serving,
// its handler returns normally and reaches OnServeDone.
func (s *Serial) OnCancel(req HandlerReq) {
if len(s.queued) == 0 {
return
}
kept := s.queued[:0]
removed := false
for _, q := range s.queued {
if q.Respond == req.Respond {
removed = true
continue
}
kept = append(kept, q)
}
s.queued = kept
if removed {
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
broadcastQueuePositions(s.queued)
}
}
// OnUnload reconciles state for an unload, stops the targeted processes, and
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
// models are failed; an active *loading* request for an unloaded model is failed
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
// active *serving* request is left for its handler to end when StopProcesses
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
// the processes being stopped on return.
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
targetSet := make(map[string]bool, len(targets))
for _, id := range targets {
targetSet[id] = true
}
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
s.effects.GrantError(*s.active, unloadErr)
s.active = nil
s.phase = phaseIdle
}
if len(s.queued) > 0 {
kept := s.queued[:0]
for _, q := range s.queued {
if targetSet[q.Model] {
s.effects.GrantError(q, unloadErr)
continue
}
kept = append(kept, q)
}
s.queued = kept
broadcastQueuePositions(s.queued)
}
s.effects.StopProcesses(timeout, targets)
// A still-serving active request advances via OnServeDone when its killed
// handler returns; only start the next job when nothing is active now.
if s.active == nil {
s.startNext()
}
}
// OnShutdown grants err to every request the scheduler still holds: an active
// loading request and all queued requests. A serving request is torn down with
// its process by the baseRouter.
func (s *Serial) OnShutdown(err error) {
if s.active != nil && s.phase == phaseSwapping {
s.effects.GrantError(*s.active, err)
s.active = nil
s.phase = phaseIdle
}
for _, q := range s.queued {
s.effects.GrantError(q, err)
}
s.queued = nil
}
// otherRunning returns every running model except target, sorted for
// deterministic eviction.
func (s *Serial) otherRunning(target string) []string {
var out []string
for id := range s.effects.RunningModels() {
if id != target {
out = append(out, id)
}
}
sort.Strings(out)
return out
}
+391
View File
@@ -0,0 +1,391 @@
package scheduler
import (
"errors"
"io"
"testing"
"time"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
)
// Serial methods all run on the router's single run-loop goroutine, so these
// tests drive them directly and synchronously, reusing fakeEffects and the
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
// served request finishes via OnServeDone — the events the run loop delivers.
func newSerial(eff Effects) *Serial {
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
}
// lastStart returns the most recent StartSwap record.
func lastStart(t *testing.T, eff *fakeEffects) startRec {
t.Helper()
if len(eff.starts) == 0 {
t.Fatal("no StartSwap recorded")
}
return eff.starts[len(eff.starts)-1]
}
func sameSet(a, b []string) bool {
if len(a) != len(b) {
return false
}
m := map[string]int{}
for _, x := range a {
m[x]++
}
for _, x := range b {
m[x]--
}
for _, v := range m {
if v != 0 {
return false
}
}
return true
}
// servedOrder returns the model IDs of every successful serve grant in order.
func servedOrder(eff *fakeEffects) []string {
var out []string
for _, g := range eff.grants {
if g.err == nil && g.serve {
out = append(out, g.model)
}
}
return out
}
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
s := newSerial(eff)
s.OnRequest(req("a"))
if got := len(eff.starts); got != 0 {
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
}
if got := eff.served("a"); got != 1 {
t.Errorf("served(a)=%d want 1", got)
}
}
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a"))
if got := eff.startsFor("a"); got != 1 {
t.Fatalf("StartSwap(a)=%d want 1", got)
}
if got := eff.served("a"); got != 0 {
t.Errorf("served(a)=%d want 0 before load completes", got)
}
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"})
if got := eff.served("a"); got != 1 {
t.Errorf("served(a)=%d want 1 after load", got)
}
}
func TestSerial_UnknownModel(t *testing.T) {
eff := newFakeEffects() // no states => unknown
s := newSerial(eff)
s.OnRequest(req("ghost"))
if len(eff.starts) != 0 {
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
}
if eff.errored("ghost") != 1 {
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
}
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
}
}
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
eff := newFakeEffects()
eff.states["x"] = process.StateReady // already running
eff.states["y"] = process.StateReady // also running (e.g. left over)
eff.states["a"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a"))
st := lastStart(t, eff)
if st.model != "a" {
t.Fatalf("loading %s want a", st.model)
}
if !sameSet(st.evict, []string{"x", "y"}) {
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
}
}
// TestSerial_OneJobAtATime verifies a second request waits while the first is
// serving, and only starts after the first finishes.
func TestSerial_OneJobAtATime(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
eff.states["b"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a")) // served immediately
s.OnRequest(req("b")) // must wait — a is serving
if got := eff.startsFor("b"); got != 0 {
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
}
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1", got)
}
// a finishes -> b may now load (evicting a).
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
if got := eff.startsFor("b"); got != 1 {
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
}
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
t.Errorf("b evict=%v want [a]", st.evict)
}
}
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
// already-loaded model run without a reload, one after another.
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a")) // cold load
s.OnRequest(req("a")) // queued behind the first
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
}
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
if got := eff.served("a"); got != 2 {
t.Fatalf("served(a)=%d want 2", got)
}
if got := eff.startsFor("a"); got != 1 {
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
}
}
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
// qwen36 execute in EXACTLY that order with evictions between each model switch,
// including reloading qwen36 at the end even though it ran first.
func TestSerial_StrictArrivalOrder(t *testing.T) {
eff := newFakeEffects()
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
eff.states[m] = process.StateStopped
}
s := newSerial(eff)
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
s.OnRequest(req(m))
}
// Only the first job starts loading; the rest wait their turn.
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
}
// step completes the current model's load+serve and returns control to the
// scheduler, which must start the next queued model.
step := func(model string, wantEvict []string) {
t.Helper()
st := lastStart(t, eff)
if st.model != model {
t.Fatalf("loading %q want %q", st.model, model)
}
if !sameSet(st.evict, wantEvict) {
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
}
// Simulate the eviction + load actually happening.
for _, e := range st.evict {
eff.states[e] = process.StateStopped
}
eff.states[model] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: model})
s.OnServeDone(ServeDoneEvent{ModelID: model})
}
step("qwen36", nil) // cold load, nothing else running
step("qwen35", []string{"qwen36"}) // evict qwen36
step("sdxl", []string{"qwen35"}) // evict qwen35
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
if got := servedOrder(eff); !sameOrder(got, want) {
t.Fatalf("serve order=%v want %v", got, want)
}
}
func sameOrder(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a"))
s.OnRequest(req("b")) // queued behind a
// a's load fails: its caller is errored and b proceeds.
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
if eff.errored("a") != 1 {
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
}
if got := eff.startsFor("b"); got != 1 {
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
}
}
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
// caller has disconnected by serve time, the queue advances to the next request.
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
eff.serveResult["a"] = false // a's caller is gone by grant time
s := newSerial(eff)
s.OnRequest(req("a"))
s.OnRequest(req("b")) // queued
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
if got := eff.served("a"); got != 0 {
t.Errorf("served(a)=%d want 0 (caller gone)", got)
}
if got := eff.startsFor("b"); got != 1 {
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
}
}
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(reqCh("a")) // starts loading a
cancelled := reqCh("b")
s.OnRequest(cancelled) // queued behind a
if len(s.queued) != 1 {
t.Fatalf("queued=%d want 1", len(s.queued))
}
s.OnCancel(cancelled)
if len(s.queued) != 0 {
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
}
// a completes; b is gone, so nothing starts for it.
eff.states["a"] = process.StateReady
s.OnSwapDone(SwapDone{ModelID: "a"})
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
if got := eff.startsFor("b"); got != 0 {
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
}
}
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
eff.states["c"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a")) // active (loading)
s.OnRequest(req("b")) // queued
s.OnRequest(req("c")) // queued
s.OnShutdown(errors.New("shutting down"))
if got := eff.errored(""); got != 3 {
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
}
if len(s.queued) != 0 {
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
}
}
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
// actively serving does not strand the queue: OnUnload stops the process but
// leaves the active request to end via OnServeDone, which then advances.
func TestSerial_OnUnload_WhileServing(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateReady
eff.states["b"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a")) // served immediately (a ready)
s.OnRequest(req("b")) // queued behind a
if got := eff.served("a"); got != 1 {
t.Fatalf("served(a)=%d want 1", got)
}
// Unload a while it is serving: the process is stopped, but the queue must
// not advance yet — the active serve is still outstanding.
s.OnUnload([]string{"a"}, time.Second)
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
}
if got := eff.startsFor("b"); got != 0 {
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
}
// The killed handler returns -> OnServeDone advances to b.
eff.states["a"] = process.StateStopped
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
if got := eff.startsFor("b"); got != 1 {
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
}
}
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
eff := newFakeEffects()
eff.states["a"] = process.StateStopped
eff.states["b"] = process.StateStopped
s := newSerial(eff)
s.OnRequest(req("a")) // active (loading a)
s.OnRequest(req("b")) // queued
// Unload a: its active load is failed and a is stopped.
s.OnUnload([]string{"a"}, time.Second)
if eff.errored("a") != 1 {
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
}
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
}
// b was queued and not unloaded; with a's load cancelled it now starts.
if got := eff.startsFor("b"); got != 1 {
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
}
}
+27 -29
View File
@@ -2,6 +2,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"sort" "sort"
"strings" "strings"
@@ -9,6 +10,7 @@ import (
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/event" "github.com/mostlygeek/llama-swap/internal/event"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared" "github.com/mostlygeek/llama-swap/internal/shared"
) )
@@ -271,7 +273,7 @@ func (s *Server) startPreload() {
if err != nil { if err != nil {
continue continue
} }
req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID})) req = req.WithContext(shared.SetContext(req.Context(), shared.ReqContextData{Model: modelID, ModelID: modelID, Metadata: make(map[string]string)}))
dw := &discardResponseWriter{status: http.StatusOK} dw := &discardResponseWriter{status: http.StatusOK}
s.local.ServeHTTP(dw, req) s.local.ServeHTTP(dw, req)
@@ -314,7 +316,7 @@ func handleUpstreamRedirect(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
upstreamPath := r.PathValue("upstreamPath") upstreamPath := r.PathValue("upstreamPath")
searchName, modelID, remainingPath, found := findModelInPath(s.cfg, "/"+upstreamPath) searchName, modelID, remainingPath, found := shared.FindModelInPath(s.cfg, "/"+upstreamPath)
if !found { if !found {
shared.SendResponse(w, r, http.StatusNotFound, "model not found") shared.SendResponse(w, r, http.StatusNotFound, "model not found")
return return
@@ -338,7 +340,29 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
// Strip the /upstream/<model> prefix before forwarding. // Strip the /upstream/<model> prefix before forwarding.
r.URL.Path = remainingPath r.URL.Path = remainingPath
// Pin the resolved model so the router skips body/query extraction. // Pin the resolved model so the router skips body/query extraction.
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID})) *r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
// If the path matches an upstream.ignorePaths entry and the model is
// not already loaded, refuse the request without triggering a swap. The
// server was not able to process the response because the model was not
// already loaded.
for _, re := range s.cfg.Upstream.IgnorePaths {
if !re.MatchString(remainingPath) {
continue
}
if s.local.Handles(modelID) {
state, ok := s.local.RunningModels()[modelID]
if !ok || state != process.StateReady {
shared.SendResponse(w, r, http.StatusConflict,
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
return
}
}
// Either the model is already loaded (no swap would be triggered)
// or this is a peer model (peer proxying never swaps). Fall through
// to normal dispatch.
break
}
switch { switch {
case s.local.Handles(modelID): case s.local.Handles(modelID):
@@ -349,29 +373,3 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID) shared.SendResponse(w, r, http.StatusNotFound, "no router for model "+modelID)
} }
} }
// findModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func findModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
return name, modelID, "/" + strings.Join(parts[i+1:], "/"), true
}
}
return "", "", "", false
}
+169 -2
View File
@@ -2,11 +2,17 @@ package server
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp"
"strings"
"testing" "testing"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/process"
"github.com/mostlygeek/llama-swap/internal/shared"
) )
func TestServer_HandleListModels(t *testing.T) { func TestServer_HandleListModels(t *testing.T) {
@@ -78,6 +84,7 @@ func TestServer_HandleListModels_Aliases(t *testing.T) {
func TestServer_FindModelInPath(t *testing.T) { func TestServer_FindModelInPath(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{ cfg := config.Config{Models: map[string]config.ModelConfig{
"author": {},
"author/model": {}, "author/model": {},
"simple": {}, "simple": {},
}} }}
@@ -91,13 +98,14 @@ func TestServer_FindModelInPath(t *testing.T) {
{"/simple/v1/chat", "simple", "/v1/chat", true}, {"/simple/v1/chat", "simple", "/v1/chat", true},
{"/author/model/v1/chat", "author/model", "/v1/chat", true}, {"/author/model/v1/chat", "author/model", "/v1/chat", true},
{"/author/model", "author/model", "/", true}, {"/author/model", "author/model", "/", true},
{"/author/v1/chat", "author", "/v1/chat", true},
{"/missing/v1", "", "", false}, {"/missing/v1", "", "", false},
{"/", "", "", false}, {"/", "", "", false},
} }
for _, c := range cases { for _, c := range cases {
name, _, rem, found := findModelInPath(cfg, c.path) name, _, rem, found := shared.FindModelInPath(cfg, c.path)
if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) { if found != c.wantFound || name != c.wantName || (found && rem != c.wantRem) {
t.Errorf("findModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)", t.Errorf("FindModelInPath(%q) = (%q,%q,%v), want (%q,%q,%v)",
c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound) c.path, name, rem, found, c.wantName, c.wantRem, c.wantFound)
} }
} }
@@ -133,6 +141,165 @@ func TestServer_HandleUpstream(t *testing.T) {
}) })
} }
func upstreamMetricsServer(response string) *Server {
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
proxylog := logmon.NewWriter(io.Discard)
s := &Server{
cfg: cfg,
muxlog: logmon.NewWriter(io.Discard),
proxylog: proxylog,
upstreamlog: logmon.NewWriter(io.Discard),
inflight: &inflightCounter{},
metrics: newMetricsMonitor(proxylog, 10, 0),
local: newStubRouter([]string{"m1"}, response),
peer: newStubRouter(nil, ""),
}
s.routes()
return s
}
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
// Compile a pattern that matches static asset suffixes.
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
local := newStubRouter([]string{"m1"}, "upstream-body")
// running is nil/empty: model is not in RunningModels() => not loaded.
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{
Models: map[string]config.ModelConfig{"m1": {}},
Upstream: config.UpstreamConfig{
IgnorePaths: []*regexp.Regexp{pattern},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
if w.Code != http.StatusConflict {
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
}
if !strings.Contains(w.Body.String(), "not loaded") {
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
}
})
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
local := newStubRouter([]string{"m1"}, "upstream-body")
local.running = map[string]process.ProcessState{"m1": process.StateReady}
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{
Models: map[string]config.ModelConfig{"m1": {}},
Upstream: config.UpstreamConfig{
IgnorePaths: []*regexp.Regexp{pattern},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
}
})
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
local := newStubRouter([]string{"m1"}, "upstream-body")
s := newTestServer(local, newStubRouter(nil, ""))
s.cfg = config.Config{
Models: map[string]config.ModelConfig{"m1": {}},
Upstream: config.UpstreamConfig{
IgnorePaths: []*regexp.Regexp{pattern},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
}
})
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
// Peer routers do not appear via RunningModels on the local router;
// they should fall through to normal dispatch without 409.
local := newStubRouter(nil, "")
peer := newStubRouter([]string{"m1"}, "peer-body")
s := newTestServer(local, peer)
s.cfg = config.Config{
Models: map[string]config.ModelConfig{"m1": {}},
Upstream: config.UpstreamConfig{
IgnorePaths: []*regexp.Regexp{pattern},
},
}
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
}
})
}
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
s := upstreamMetricsServer(resp)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK || w.Body.String() != resp {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
entries := s.metrics.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 metrics entry, got %d", len(entries))
}
if entries[0].Model != "m1" {
t.Errorf("model = %q, want m1", entries[0].Model)
}
if entries[0].ReqPath != "/v1/chat/completions" {
t.Errorf("req_path = %q, want /v1/chat/completions", entries[0].ReqPath)
}
if entries[0].Tokens.InputTokens != 3 || entries[0].Tokens.OutputTokens != 5 {
t.Errorf("tokens = %+v, want input=3 output=5", entries[0].Tokens)
}
}
func TestServer_HandleUpstream_MetricsSkipsUnsupportedPath(t *testing.T) {
s := upstreamMetricsServer("ok")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/probe", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
s.ServeHTTP(w, req)
if w.Code != http.StatusOK || w.Body.String() != "ok" {
t.Fatalf("status=%d body=%q", w.Code, w.Body.String())
}
if len(s.metrics.getMetrics()) != 0 {
t.Errorf("want no metrics entries for unsupported path, got %d", len(s.metrics.getMetrics()))
}
}
func TestServer_HandleUpstream_MetricsSkipsGET(t *testing.T) {
s := upstreamMetricsServer(`{"usage":{}}`)
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if len(s.metrics.getMetrics()) != 0 {
t.Errorf("want no metrics entries for GET upstream, got %d", len(s.metrics.getMetrics()))
}
}
func TestServer_HandleMetrics_Unavailable(t *testing.T) { func TestServer_HandleMetrics_Unavailable(t *testing.T) {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, "")) s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
+4 -1
View File
@@ -105,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
// filtered to samples after the ?after=<RFC3339> timestamp. // filtered to samples after the ?after=<RFC3339> timestamp.
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
if s.perf == nil { if s.perf == nil {
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
return return
} }
@@ -136,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ json.NewEncoder(w).Encode(map[string]any{
"enabled": true,
"sys_stats": sysStats, "sys_stats": sysStats,
"gpu_stats": gpuStats, "gpu_stats": gpuStats,
}) })
+1 -1
View File
@@ -76,7 +76,7 @@ func (s *Server) getLogger(logMonitorID string) (*logmon.Monitor, error) {
case "upstream": case "upstream":
return s.upstreamlog, nil return s.upstreamlog, nil
default: default:
if _, modelID, _, found := findModelInPath(s.cfg, "/"+logMonitorID); found { if _, modelID, _, found := shared.FindModelInPath(s.cfg, "/"+logMonitorID); found {
if log, ok := s.local.ProcessLogger(modelID); ok { if log, ok := s.local.ProcessLogger(modelID); ok {
return log, nil return log, nil
} }
+132 -33
View File
@@ -25,6 +25,8 @@ import (
// TokenMetrics holds token usage and performance metrics. // TokenMetrics holds token usage and performance metrics.
type TokenMetrics struct { type TokenMetrics struct {
CachedTokens int `json:"cache_tokens"` CachedTokens int `json:"cache_tokens"`
DraftTokens int `json:"draft_tokens"`
DraftAccTokens int `json:"draft_acc_tokens"`
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
PromptPerSecond float64 `json:"prompt_per_second"` PromptPerSecond float64 `json:"prompt_per_second"`
@@ -33,15 +35,17 @@ type TokenMetrics struct {
// ActivityLogEntry represents parsed token statistics from llama-server logs. // ActivityLogEntry represents parsed token statistics from llama-server logs.
type ActivityLogEntry struct { type ActivityLogEntry struct {
ID int `json:"id"` ID int `json:"id"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Model string `json:"model"` Model string `json:"model"`
ReqPath string `json:"req_path"` ReqPath string `json:"req_path"`
RespContentType string `json:"resp_content_type"` RespContentType string `json:"resp_content_type"`
RespStatusCode int `json:"resp_status_code"` RespStatusCode int `json:"resp_status_code"`
Tokens TokenMetrics `json:"tokens"` Tokens TokenMetrics `json:"tokens"`
DurationMs int `json:"duration_ms"` DurationMs int `json:"duration_ms"`
HasCapture bool `json:"has_capture"` HasCapture bool `json:"has_capture"`
ErrorMsg string `json:"error_msg,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
} }
// ActivityLogEvent carries a single activity log entry to event subscribers. // ActivityLogEvent carries a single activity log entry to event subscribers.
@@ -122,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
} }
// record parses a completed response body and stores/emits an activity entry. // record parses a completed response body and stores/emits an activity entry.
// When captures are enabled, a zstd+CBOR capture is stored for successful // Successful requests store a zstd+CBOR capture (when enabled) with cf
// requests, with cf controlling which request/response parts are retained. // controlling which parts are retained. Failed (non-200) requests capture the
// reqBody and reqHeaders are the request data buffered before dispatch. // request only and set ErrorMsg to a description of the failure, so the error
// can be inspected without storing unreadable raw response bytes. reqBody and
// reqHeaders are the request data buffered before dispatch.
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) { func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
tm := ActivityLogEntry{ tm := ActivityLogEntry{
Timestamp: time.Now(), Timestamp: time.Now(),
@@ -135,6 +141,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()), DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
} }
if ctxData, ok := shared.ReadContext(r.Context()); ok && len(ctxData.Metadata) > 0 {
tm.Metadata = make(map[string]string, len(ctxData.Metadata))
for k, v := range ctxData.Metadata {
tm.Metadata[k] = v
}
}
queueAndEmit := func() { queueAndEmit := func() {
tm.ID = mp.queueMetrics(tm) tm.ID = mp.queueMetrics(tm)
mp.emitMetric(tm) mp.emitMetric(tm)
@@ -142,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
if recorder.Status() != http.StatusOK { if recorder.Status() != http.StatusOK {
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path) mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
queueAndEmit() decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
tm.ID = mp.queueMetrics(tm)
// Capture the request only; the failure is surfaced via ErrorMsg
// rather than storing the (possibly undisplayable) response body.
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
mp.emitMetric(tm)
return return
} }
@@ -157,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
decoded, err := decompressBody(body, encoding) decoded, err := decompressBody(body, encoding)
if err != nil { if err != nil {
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path) mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
queueAndEmit() queueAndEmit()
return return
} }
@@ -195,28 +215,99 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
} }
tm.ID = mp.queueMetrics(tm) tm.ID = mp.queueMetrics(tm)
if mp.enableCaptures { tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
capture := ReqRespCapture{ mp.emitMetric(tm)
ID: tm.ID, }
ReqPath: r.URL.Path,
ReqHeaders: reqHeaders, // storeCapture assembles a ReqRespCapture for id, honoring the captureFields
} // mask, and stores it when captures are enabled. body is the response body to
if cf&captureReqBody != 0 { // capture (already decompressed by the caller); pass nil to omit it. Returns
capture.ReqBody = reqBody // true if a capture was stored.
} func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
if cf&captureRespHeaders != 0 { if !mp.enableCaptures {
capture.RespHeaders = headerMap(recorder.Header()) return false
redactHeaders(capture.RespHeaders) }
delete(capture.RespHeaders, "Content-Encoding") capture := ReqRespCapture{
} ID: id,
if cf&captureRespBody != 0 { ReqPath: r.URL.Path,
capture.RespBody = body ReqHeaders: reqHeaders,
} }
if mp.addCapture(capture) { if cf&captureReqBody != 0 {
tm.HasCapture = true capture.ReqBody = reqBody
}
if cf&captureRespHeaders != 0 {
capture.RespHeaders = headerMap(recorder.Header())
redactHeaders(capture.RespHeaders)
delete(capture.RespHeaders, "Content-Encoding")
}
if cf&captureRespBody != 0 {
capture.RespBody = body
}
return mp.addCapture(capture)
}
// decodeResponseBody returns the buffered response body, decompressing it when
// the upstream set a Content-Encoding we recognize. On decompression failure it
// logs a warning and returns an error so the caller can record a description
// (via ErrorMsg) instead of storing unreadable raw bytes.
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
body := recorder.body.Bytes()
if len(body) == 0 {
return nil, nil
}
encoding := recorder.Header().Get("Content-Encoding")
if encoding == "" {
return body, nil
}
decoded, err := decompressBody(body, encoding)
if err != nil {
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
return nil, err
}
return decoded, nil
}
// errorMessagePaths lists JSON paths where a human-readable error message can
// live across OpenAI- and llama.cpp-style error responses.
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
// extractErrorMessage pulls a human-readable error string from a JSON error
// response. Returns "" if no message is found or the body is not valid JSON.
func extractErrorMessage(body []byte) string {
if !gjson.ValidBytes(body) {
return ""
}
parsed := gjson.ParseBytes(body)
for _, path := range errorMessagePaths {
v := parsed.Get(path)
if v.Exists() && v.Type == gjson.String {
if s := strings.TrimSpace(v.String()); s != "" {
return s
}
} }
} }
mp.emitMetric(tm) return ""
}
// failedErrorMessage builds a human-readable description for a non-200 response.
// It prefers an error message parsed from the (decompressed) body and falls back
// to the HTTP status text. A non-nil decErr indicates the body could not be
// decoded, in which case the decode error is described instead.
func failedErrorMessage(status int, body []byte, decErr error) string {
const maxLen = 500
if decErr != nil {
return fmt.Sprintf("response decode failed: %v", decErr)
}
if msg := extractErrorMessage(body); msg != "" {
if len(msg) > maxLen {
msg = msg[:maxLen] + "..."
}
return msg
}
if text := http.StatusText(status); text != "" {
return fmt.Sprintf("%d %s", status, text)
}
return fmt.Sprintf("HTTP %d", status)
} }
// usagePaths lists the JSON paths where a per-event usage object can live. // usagePaths lists the JSON paths where a per-event usage object can live.
@@ -337,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
durationMs := wallDurationMs durationMs := wallDurationMs
tokensPerSecond := -1.0 tokensPerSecond := -1.0
promptPerSecond := -1.0 promptPerSecond := -1.0
draftTokens := -1
draftAccTokens := -1
if timings.Exists() { if timings.Exists() {
inputTokens = timings.Get("prompt_n").Int() inputTokens = timings.Get("prompt_n").Int()
@@ -350,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() { if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = cachedValue.Int() cachedTokens = cachedValue.Int()
} }
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
draftTokens = int(timings.Get("draft_n").Int())
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
}
} }
return ActivityLogEntry{ return ActivityLogEntry{
@@ -357,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
Model: modelID, Model: modelID,
Tokens: TokenMetrics{ Tokens: TokenMetrics{
CachedTokens: int(cachedTokens), CachedTokens: int(cachedTokens),
DraftTokens: draftTokens,
DraftAccTokens: draftAccTokens,
InputTokens: int(inputTokens), InputTokens: int(inputTokens),
OutputTokens: int(outputTokens), OutputTokens: int(outputTokens),
PromptPerSecond: promptPerSecond, PromptPerSecond: promptPerSecond,
+22 -2
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/mostlygeek/llama-swap/internal/chain" "github.com/mostlygeek/llama-swap/internal/chain"
"github.com/mostlygeek/llama-swap/internal/config" "github.com/mostlygeek/llama-swap/internal/config"
@@ -21,8 +22,27 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
return return
} }
// Determine the model-routed endpoint path. Regular routes are
// already meterable; /upstream/<model>/<path> is metered only when
// the remaining path matches a model-dispatched endpoint.
checkPath := r.URL.Path
if strings.HasPrefix(r.URL.Path, "/upstream/") {
var found bool
_, _, checkPath, found = shared.FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
if !found {
next.ServeHTTP(w, r)
return
}
}
if !isMetricsRecordPath(checkPath) {
next.ServeHTTP(w, r)
return
}
// Resolve the model now so downstream dispatch hits the context // Resolve the model now so downstream dispatch hits the context
// fast path; FetchContext restores the request body. // fast path; FetchContext restores the request body for regular
// routes and extracts the model from the URL for /upstream routes.
data, err := shared.FetchContext(r, cfg) data, err := shared.FetchContext(r, cfg)
if err != nil { if err != nil {
shared.SendError(w, r, shared.ErrNoModelInContext) shared.SendError(w, r, shared.ErrNoModelInContext)
@@ -31,7 +51,7 @@ func CreateMetricsMiddleware(mm *metricsMonitor, cfg config.Config) chain.Middle
// Buffer the request body/headers for capture before dispatch // Buffer the request body/headers for capture before dispatch
// consumes them. // consumes them.
cf := captureFieldsFor(r.URL.Path) cf := captureFieldsFor(checkPath)
var reqBody []byte var reqBody []byte
var reqHeaders map[string]string var reqHeaders map[string]string
if mm.enableCaptures { if mm.enableCaptures {
+237
View File
@@ -1,9 +1,16 @@
package server package server
import ( import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
"github.com/mostlygeek/llama-swap/internal/shared"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -56,6 +63,199 @@ func TestServer_ProcessStreamingResponse_NoData(t *testing.T) {
} }
} }
func TestMetricsMonitor_RecordMetadata(t *testing.T) {
mm := newMetricsMonitor(nil, 10, 0)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"usage":{}}`))
r = r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{
ModelID: "m",
Metadata: map[string]string{"client": "web", "trace": "abc"},
}))
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.WriteHeader(http.StatusOK)
copier.Write([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
mm.record("m", r, copier, 0, nil, nil)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
if entries[0].Metadata["client"] != "web" {
t.Errorf("client = %q, want web", entries[0].Metadata["client"])
}
if entries[0].Metadata["trace"] != "abc" {
t.Errorf("trace = %q, want abc", entries[0].Metadata["trace"])
}
}
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
reqHeaders := map[string]string{"content-type": "application/json"}
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.Header().Set("Content-Type", "application/json")
copier.WriteHeader(http.StatusBadGateway)
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
reqBody := []byte(`{"model":"m","messages":[]}`)
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
entry := entries[0]
if entry.RespStatusCode != http.StatusBadGateway {
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
}
if entry.ErrorMsg != "model unavailable" {
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
}
if !entry.HasCapture {
t.Fatal("failed request should capture the request so it can be inspected")
}
got := mm.getCaptureByID(entry.ID)
if got == nil {
t.Fatal("capture not found")
}
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
t.Errorf("req body = %q", got.ReqBody)
}
if len(got.RespBody) != 0 {
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
}
if got.RespHeaders["Content-Type"] != "application/json" {
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
}
}
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.WriteHeader(http.StatusBadGateway)
copier.Write([]byte("<html>upstream down</html>"))
mm.record("m", r, copier, captureAll, nil, nil)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
if entries[0].ErrorMsg != "502 Bad Gateway" {
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
}
}
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.WriteHeader(http.StatusInternalServerError)
copier.Write([]byte(`{"error":"boom"}`))
mm.record("m", r, copier, captureAll, []byte("req"), nil)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
if entries[0].HasCapture {
t.Fatal("captures disabled, HasCapture should be false")
}
// ErrorMsg is independent of whether captures are enabled.
if entries[0].ErrorMsg != "boom" {
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
}
if mm.getCaptureByID(entries[0].ID) != nil {
t.Fatal("no capture should be stored when disabled")
}
}
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.Header().Set("Content-Encoding", "gzip")
copier.WriteHeader(http.StatusOK)
copier.Write([]byte("not-really-gzip"))
mm.record("m", r, copier, captureAll, []byte("req"), nil)
entries := mm.getMetrics()
if len(entries) != 1 {
t.Fatalf("want 1 entry, got %d", len(entries))
}
if entries[0].ErrorMsg == "" {
t.Fatal("expected ErrorMsg for decompression failure")
}
// Raw bytes must not be stored when the body could not be decoded.
if entries[0].HasCapture {
t.Fatal("decompression failure should not store a capture")
}
}
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
// No Content-Encoding: body returned unchanged.
w := httptest.NewRecorder()
copier := newBodyCopier(w)
copier.Write([]byte("plain"))
got, err := mm.decodeResponseBody(copier, "/p")
if err != nil || string(got) != "plain" {
t.Fatalf("plain body = %q, err = %v", got, err)
}
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
w2 := httptest.NewRecorder()
copier2 := newBodyCopier(w2)
copier2.Header().Set("Content-Encoding", "gzip")
copier2.Write([]byte("not-really-gzip"))
got, err = mm.decodeResponseBody(copier2, "/p")
if err == nil {
t.Fatal("expected decompression error")
}
if got != nil {
t.Errorf("expected nil body on failure, got %q", got)
}
}
func TestServer_ExtractErrorMessage(t *testing.T) {
cases := []struct {
name string
body string
want string
}{
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
{"string error", `{"error":"bad request"}`, "bad request"},
{"message field", `{"message":"nope"}`, "nope"},
{"detail field", `{"detail":"oops"}`, "oops"},
{"object error ignored", `{"error":{"code":42}}`, ""},
{"no error", `{"usage":{}}`, ""},
{"invalid json", `not-json`, ""},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
}
})
}
}
func TestServer_ParseMetrics_Infill(t *testing.T) { func TestServer_ParseMetrics_Infill(t *testing.T) {
// /infill responses are arrays; timings live in the last element. // /infill responses are arrays; timings live in the last element.
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]` body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
@@ -72,3 +272,40 @@ func TestServer_ParseMetrics_Infill(t *testing.T) {
t.Fatalf("tokens = %+v", entry.Tokens) t.Fatalf("tokens = %+v", entry.Tokens)
} }
} }
// TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody verifies that
// an /upstream/<model>/v1/audio/speech request uses the path-specific capture
// mask (headers only) rather than falling back to captureAll.
func TestServer_MetricsMiddleware_UpstreamAudioCaptureSkipsRespBody(t *testing.T) {
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 100, 5)
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "audio/mpeg")
w.WriteHeader(http.StatusOK)
w.Write([]byte("BINARY-AUDIO-DATA"))
})
handler := CreateMetricsMiddleware(mm, cfg)(inner)
req := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/audio/speech", strings.NewReader(`{"model":"m1"}`))
handler.ServeHTTP(httptest.NewRecorder(), req)
entries := mm.getMetrics()
if len(entries) == 0 {
t.Fatal("no metrics recorded")
}
last := entries[len(entries)-1]
if !last.HasCapture {
t.Fatal("expected capture to be stored")
}
cap := mm.getCaptureByID(last.ID)
if cap == nil {
t.Fatal("capture not found")
}
if len(cap.RespBody) != 0 {
t.Errorf("RespBody stored for /upstream audio route (len=%d); want path-specific mask to skip body", len(cap.RespBody))
}
if len(cap.RespHeaders) == 0 {
t.Error("RespHeaders not stored; want captureRespHeaders mask")
}
}
+25 -2
View File
@@ -89,6 +89,27 @@ var modelGetRoutes = []string{
"/sdapi/v1/loras", "/sdapi/v1/loras",
} }
// isMetricsRecordPath reports whether path is one of the model-dispatched
// endpoints that the metrics middleware records in the activity log.
func isMetricsRecordPath(path string) bool {
for _, p := range modelPostJSONRoutes {
if p == path {
return true
}
}
for _, p := range modelPostFormRoutes {
if p == path {
return true
}
}
for _, p := range modelGetRoutes {
if p == path {
return true
}
}
return false
}
// BuildInfo carries version metadata surfaced by GET /api/version. // BuildInfo carries version metadata surfaced by GET /api/version.
type BuildInfo struct { type BuildInfo struct {
Version string Version string
@@ -219,9 +240,11 @@ func (s *Server) routes() {
mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload)) mux.Handle("GET /unload", apiChain.ThenFunc(s.handleUnload))
mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning)) mux.Handle("GET /running", apiChain.ThenFunc(s.handleRunning))
// Upstream passthrough. // Upstream passthrough. Meter only the model-dispatched endpoints that can
// produce token usage/timings.
upstreamChain := apiChain.Append(CreateMetricsMiddleware(s.metrics, s.cfg))
mux.HandleFunc("GET /upstream", handleUpstreamRedirect) mux.HandleFunc("GET /upstream", handleUpstreamRedirect)
mux.Handle("/upstream/{upstreamPath...}", apiChain.ThenFunc(s.handleUpstream)) mux.Handle("/upstream/{upstreamPath...}", upstreamChain.ThenFunc(s.handleUpstream))
// API group (API-key protected) consumed by the UI. // API group (API-key protected) consumed by the UI.
mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll)) mux.Handle("POST /api/models/unload", apiChain.ThenFunc(s.handleAPIUnloadAll))
+90 -4
View File
@@ -26,6 +26,9 @@ type ReqContextData struct {
ModelID string ModelID string
Streaming bool Streaming bool
SendLoadingState bool SendLoadingState bool
// Metadata is a request-scoped key/value bag that handlers may mutate
// while processing. The metrics middleware copies it into ActivityLogEntry.
Metadata map[string]string
} }
var ( var (
@@ -88,16 +91,24 @@ func SendResponse(w http.ResponseWriter, r *http.Request, status int, message st
w.Write(resp) w.Write(resp)
} }
// FetchContext will attempt to get the model id from the context then // FetchContext will attempt to get the model id from the context, then
// from the model body. If it extracts the model from the body it will // from an /upstream/<model> path prefix, then from the request body/query.
// store the model in the context for downstream handlers. An error // If it extracts the model it will store it in the context for downstream
// will be returned when model can not be fetch from either location. // handlers. An error will be returned when a model cannot be identified.
func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) { func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
data, ok := ReadContext(r.Context()) data, ok := ReadContext(r.Context())
if ok { if ok {
return data, nil return data, nil
} }
if strings.HasPrefix(r.URL.Path, "/upstream/") {
if data, ok := extractUpstreamContext(r, cfg); ok {
*r = *r.WithContext(SetContext(r.Context(), data))
return data, nil
}
return ReqContextData{}, ErrNoModelInContext
}
if data, err := extractContext(r); err == nil && data.Model != "" { if data, err := extractContext(r); err == nil && data.Model != "" {
realName, _ := cfg.RealModelName(data.Model) realName, _ := cfg.RealModelName(data.Model)
if realName == "" { if realName == "" {
@@ -114,6 +125,59 @@ func FetchContext(r *http.Request, cfg config.Config) (ReqContextData, error) {
return ReqContextData{}, ErrNoModelInContext return ReqContextData{}, ErrNoModelInContext
} }
// extractUpstreamContext resolves the model from an /upstream/<model>/... path.
func extractUpstreamContext(r *http.Request, cfg config.Config) (ReqContextData, bool) {
searchName, realName, _, found := FindModelInPath(cfg, strings.TrimPrefix(r.URL.Path, "/upstream"))
if !found {
return ReqContextData{}, false
}
return ReqContextData{
Model: searchName,
ModelID: realName,
ApiKey: ExtractAPIKey(r),
Streaming: r.URL.Query().Get("stream") == "true",
SendLoadingState: sendLoadingState(cfg, realName),
Metadata: make(map[string]string),
}, true
}
// sendLoadingState reports whether the configured model wants loading-state SSEs.
func sendLoadingState(cfg config.Config, modelID string) bool {
if mc, ok := cfg.Models[modelID]; ok {
return mc.SendLoadingState != nil && *mc.SendLoadingState
}
return false
}
// FindModelInPath walks a slash-separated path, building up segments until one
// matches a configured model. This resolves model names that contain slashes
// (e.g. "author/model"). Returns the matched name, its real model ID, the
// remaining path, and whether a match was found.
func FindModelInPath(cfg config.Config, path string) (searchName, realName, remainingPath string, found bool) {
parts := strings.Split(strings.TrimSpace(path), "/")
name := ""
for i, part := range parts {
if part == "" {
continue
}
if name == "" {
name = part
} else {
name = name + "/" + part
}
if modelID, ok := cfg.RealModelName(name); ok {
searchName = name
realName = modelID
remainingPath = "/" + strings.Join(parts[i+1:], "/")
found = true
}
}
return
}
func SetContext(ctx context.Context, data ReqContextData) context.Context { func SetContext(ctx context.Context, data ReqContextData) context.Context {
return context.WithValue(ctx, ReqContextKey, data) return context.WithValue(ctx, ReqContextKey, data)
} }
@@ -123,6 +187,25 @@ func ReadContext(ctx context.Context) (ReqContextData, bool) {
return data, ok return data, ok
} }
// SetReqData attaches a key/value pair to the request context's metadata map.
// The metadata map must already exist in the context's ReqContextData; callers
// should ensure FetchContext has run or initialize the map themselves.
// It returns an error for nil contexts or contexts without request data.
func SetReqData(ctx context.Context, key, value string) error {
if ctx == nil {
return fmt.Errorf("cannot set request metadata on nil context")
}
data, ok := ReadContext(ctx)
if !ok {
return fmt.Errorf("no request context data found")
}
if data.Metadata == nil {
return fmt.Errorf("no metadata map in request context")
}
data.Metadata[key] = value
return nil
}
// extractContext pulls fields from an HTTP request into a ReqContextData, // extractContext pulls fields from an HTTP request into a ReqContextData,
// returning whatever is available. For GET requests it reads query parameters. // returning whatever is available. For GET requests it reads query parameters.
// For POST requests it inspects Content-Type and parses JSON, // For POST requests it inspects Content-Type and parses JSON,
@@ -139,6 +222,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: q.Get("model"), Model: q.Get("model"),
Streaming: q.Get("stream") == "true", Streaming: q.Get("stream") == "true",
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
@@ -157,6 +241,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: gjson.GetBytes(bodyBytes, "model").String(), Model: gjson.GetBytes(bodyBytes, "model").String(),
Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(), Streaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
@@ -178,6 +263,7 @@ func extractContext(r *http.Request) (ReqContextData, error) {
Model: r.FormValue("model"), Model: r.FormValue("model"),
Streaming: r.FormValue("stream") == "true", Streaming: r.FormValue("stream") == "true",
ApiKey: apiKey, ApiKey: apiKey,
Metadata: make(map[string]string),
}, nil }, nil
} }
+99
View File
@@ -11,6 +11,8 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"github.com/mostlygeek/llama-swap/internal/config"
) )
func TestExtractContext_GET(t *testing.T) { func TestExtractContext_GET(t *testing.T) {
@@ -387,6 +389,38 @@ func TestExtractContext_ApiKey(t *testing.T) {
} }
} }
func TestSetReqData(t *testing.T) {
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3", Metadata: make(map[string]string)})
if err := SetReqData(ctx, "client", "web"); err != nil {
t.Fatalf("SetReqData: %v", err)
}
if err := SetReqData(ctx, "trace", "abc123"); err != nil {
t.Fatalf("SetReqData: %v", err)
}
data, ok := ReadContext(ctx)
if !ok {
t.Fatal("context data missing")
}
if data.Metadata["client"] != "web" {
t.Errorf("client = %q, want %q", data.Metadata["client"], "web")
}
if data.Metadata["trace"] != "abc123" {
t.Errorf("trace = %q, want %q", data.Metadata["trace"], "abc123")
}
}
func TestSetReqData_Errors(t *testing.T) {
if err := SetReqData(context.Background(), "k", "v"); err == nil {
t.Error("expected error when no request context data exists")
}
ctx := SetContext(context.Background(), ReqContextData{Model: "llama3", ModelID: "llama3"})
if err := SetReqData(ctx, "k", "v"); err == nil {
t.Error("expected error when metadata map is missing")
}
}
func TestServer_ExtractAPIKey(t *testing.T) { func TestServer_ExtractAPIKey(t *testing.T) {
basicHeader := func(user, pass string) string { basicHeader := func(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
@@ -424,3 +458,68 @@ func TestServer_ExtractAPIKey(t *testing.T) {
}) })
} }
} }
func TestFetchContext_UpstreamPath(t *testing.T) {
cfg := config.Config{
Models: map[string]config.ModelConfig{
"m1": {},
"author/model": {},
"real": {Aliases: []string{"nick"}},
},
}
cases := []struct {
name string
path string
wantModel string
wantModelID string
wantErr bool
}{
{"known model", "/upstream/m1/v1/chat/completions", "m1", "m1", false},
{"model with slash", "/upstream/author/model/v1/chat", "author/model", "author/model", false},
{"unknown model", "/upstream/nope/v1/chat/completions", "", "", true},
{"bare model path", "/upstream/m1/", "m1", "m1", false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, c.path, strings.NewReader(`{}`))
data, err := FetchContext(r, cfg)
if (err != nil) != c.wantErr {
t.Fatalf("wantErr=%v got err=%v", c.wantErr, err)
}
if c.wantErr {
return
}
if data.Model != c.wantModel {
t.Errorf("model = %q, want %q", data.Model, c.wantModel)
}
if data.ModelID != c.wantModelID {
t.Errorf("modelID = %q, want %q", data.ModelID, c.wantModelID)
}
if data.Metadata == nil {
t.Error("metadata map not initialized")
}
})
}
}
func TestFetchContext_UpstreamPath_DoesNotReadBody(t *testing.T) {
cfg := config.Config{Models: map[string]config.ModelConfig{"m1": {}}}
body := `{"model":"should-not-matter"}`
r := httptest.NewRequest(http.MethodPost, "/upstream/m1/v1/chat/completions", strings.NewReader(body))
_, err := FetchContext(r, cfg)
if err != nil {
t.Fatalf("FetchContext: %v", err)
}
// The body should be untouched so the upstream handler can still read it.
got, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
if string(got) != body {
t.Errorf("body was consumed: %q", string(got))
}
}
+137
View File
@@ -0,0 +1,137 @@
package configwatcher
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
// It fires OnChange when a file is added, removed, or has its mod time/size
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
// k8s ConfigMap projections where inotify is unreliable.
//
// The baseline poll establishes initial state and does not fire OnChange.
type DirWatcher struct {
Path string
Interval time.Duration
OnChange func()
}
// dirSnapshot is an ordered map of file name -> file state. The ordering is
// derived from sorted filenames so two snapshots compare deterministically
// regardless of readdir order. exists reflects whether the directory was
// readable at scan time; a missing directory yields exists=false.
type dirSnapshot struct {
exists bool
names []string
states map[string]snapshot
}
func newDirSnapshot() dirSnapshot {
return dirSnapshot{states: make(map[string]snapshot)}
}
// equal reports whether two snapshots describe the same file set and per-file
// state. A missing directory (exists=false) is treated as equal to any other
// missing directory regardless of cached names.
func (s dirSnapshot) equal(other dirSnapshot) bool {
if !s.exists && !other.exists {
return true
}
if s.exists != other.exists {
return false
}
if len(s.names) != len(other.names) {
return false
}
for i, n := range s.names {
if other.names[i] != n {
return false
}
}
for _, n := range s.names {
a, b := s.states[n], other.states[n]
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
return false
}
}
return true
}
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
// OnChange whenever the directory's YAML file set changes.
//
// Policy mirrors the single-file Watcher: disappearance (directory missing or
// empty) is treated as a transient rename-style write and stays quiet; the
// transition back to present-with-content fires OnChange.
func (w *DirWatcher) Run(ctx context.Context) {
interval := w.Interval
if interval <= 0 {
interval = DefaultInterval
}
prev := scanDir(w.Path)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
cur := scanDir(w.Path)
// Suppress transitions involving an empty or missing directory —
// these are treated as transient rename-style writes, mirroring
// the single-file Watcher. Only present-with-content →
// present-with-content (changed) or no-content →
// present-with-content fires OnChange.
prevHasContent := prev.exists && len(prev.names) > 0
curHasContent := cur.exists && len(cur.names) > 0
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
w.OnChange()
}
prev = cur
}
}
}
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
// directory cannot be read (missing, permission denied) the snapshot reports
// exists=false; the next successful scan will detect the recovery and fire
// OnChange.
func scanDir(dir string) dirSnapshot {
snap := newDirSnapshot()
entries, err := os.ReadDir(dir)
if err != nil {
return snap // exists=false
}
snap.exists = true
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
continue
}
fi, err := os.Stat(filepath.Join(dir, name))
if err != nil {
// File disappeared between ReadDir and Stat; skip it — the
// next poll will observe the removal cleanly.
continue
}
snap.names = append(snap.names, name)
snap.states[name] = snapshot{
exists: true,
modTime: fi.ModTime(),
size: fi.Size(),
}
}
sort.Strings(snap.names)
return snap
}
+199
View File
@@ -0,0 +1,199 @@
package configwatcher
import (
"context"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// startDirWatcher launches w.Run in a goroutine and returns a function that
// cancels the context and waits for Run to return.
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
return func() {
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("DirWatcher did not stop within 2s of cancel")
}
}
}
func writeYAMLInDir(t *testing.T, dir, name, content string) {
t.Helper()
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
}
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 5)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
}
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
writeYAMLInDir(t, dir, "b.yaml", "b")
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
}
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
writeYAMLInDir(t, dir, "b.yaml", "b")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
}
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
}
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
// Adding a .txt file must not fire.
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
time.Sleep(testInterval * 4)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
// Adding a .yml file must fire.
writeYAMLInDir(t, dir, "b.yml", "b")
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
}
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
// Remove the directory. No fire expected on disappearance alone.
require.NoError(t, os.RemoveAll(dir))
time.Sleep(testInterval * 3)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
// Recreate the directory and a YAML file; the recovery should fire.
require.NoError(t, os.MkdirAll(dir, 0o755))
writeYAMLInDir(t, dir, "recovered.yaml", "r")
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
}
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
// Present-with-content → empty (all YAML removed, dir still exists)
// must stay quiet — treated as transient per the documented policy.
// The transition back to content fires.
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
var n int64
stop := startDirWatcher(t, &DirWatcher{
Path: dir,
Interval: testInterval,
OnChange: func() { atomic.AddInt64(&n, 1) },
})
defer stop()
time.Sleep(testInterval * 2)
// Remove the only YAML file. Dir still exists but is empty of YAML.
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
time.Sleep(testInterval * 4)
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
// Add a YAML file back; transition to present-with-content fires.
writeYAMLInDir(t, dir, "c.yaml", "c")
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
}
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
dir := t.TempDir()
writeYAMLInDir(t, dir, "a.yaml", "a")
w := &DirWatcher{Path: dir, Interval: testInterval}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() { w.Run(ctx); close(done) }()
time.Sleep(testInterval * 2)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run did not return within 2s of cancel")
}
}
+37 -19
View File
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
} }
func main() { func main() {
flagConfig := flag.String("config", "", "path to config file (required)") flagConfig := flag.String("config", "", "path to config file")
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)") flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file") flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
flagKeyFile := flag.String("tls-key-file", "", "TLS key file") flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
@@ -68,8 +69,8 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if *flagConfig == "" { if *flagConfig == "" && *flagConfigDir == "" {
slog.Error("-config is required") slog.Error("at least one of -config or -config-dir must be provided")
os.Exit(1) os.Exit(1)
} }
@@ -88,10 +89,9 @@ func main() {
} }
} }
configPath := *flagConfig cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
cfg, err := config.LoadConfig(configPath)
if err != nil { if err != nil {
slog.Error("failed to load config", "path", configPath, "error", err) slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
os.Exit(1) os.Exit(1)
} }
@@ -187,7 +187,7 @@ func main() {
proxyLog.Info("reloading configuration") proxyLog.Info("reloading configuration")
newCfg, err := config.LoadConfig(configPath) newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
if err != nil { if err != nil {
proxyLog.Warnf("failed to reload config: %v", err) proxyLog.Warnf("failed to reload config: %v", err)
return return
@@ -230,19 +230,37 @@ func main() {
defer watcherCancel() defer watcherCancel()
if *flagWatchConfig { if *flagWatchConfig {
absConfigPath, err := filepath.Abs(configPath)
if err != nil {
slog.Error("watch-config: failed to resolve config path", "error", err)
os.Exit(1)
}
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)") proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
go func() {
(&configwatcher.Watcher{ if *flagConfig != "" {
Path: absConfigPath, absConfigPath, err := filepath.Abs(*flagConfig)
Interval: configwatcher.DefaultInterval, if err != nil {
OnChange: reload, slog.Error("watch-config: failed to resolve config path", "error", err)
}).Run(watcherCtx) os.Exit(1)
}() }
go func() {
(&configwatcher.Watcher{
Path: absConfigPath,
Interval: configwatcher.DefaultInterval,
OnChange: reload,
}).Run(watcherCtx)
}()
}
if *flagConfigDir != "" {
absConfigDir, err := filepath.Abs(*flagConfigDir)
if err != nil {
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
os.Exit(1)
}
go func() {
(&configwatcher.DirWatcher{
Path: absConfigDir,
Interval: configwatcher.DefaultInterval,
OnChange: reload,
}).Run(watcherCtx)
}()
}
} }
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
+2 -1
View File
@@ -8,7 +8,7 @@
import Performance from "./routes/Performance.svelte"; import Performance from "./routes/Performance.svelte";
import Playground from "./routes/Playground.svelte"; import Playground from "./routes/Playground.svelte";
import PlaygroundStub from "./routes/PlaygroundStub.svelte"; import PlaygroundStub from "./routes/PlaygroundStub.svelte";
import { enableAPIEvents } from "./stores/api"; import { enableAPIEvents, checkPerformanceEnabled } from "./stores/api";
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme"; import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
import { currentRoute } from "./stores/route"; import { currentRoute } from "./stores/route";
@@ -39,6 +39,7 @@
const cleanupScreenWidth = initScreenWidth(); const cleanupScreenWidth = initScreenWidth();
const cleanupSystemTheme = initSystemThemeListener(); const cleanupSystemTheme = initSystemThemeListener();
enableAPIEvents(true); enableAPIEvents(true);
checkPerformanceEnabled();
return () => { return () => {
cleanupScreenWidth(); cleanupScreenWidth();
@@ -193,7 +193,7 @@
<dialog <dialog
bind:this={dialogEl} bind:this={dialogEl}
onclose={handleDialogClose} onclose={handleDialogClose}
class="bg-surface text-txtmain rounded-lg shadow-xl max-w-4xl w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto" class="bg-surface text-txtmain rounded-lg shadow-xl max-w-[80%] w-full max-h-[90vh] p-0 backdrop:bg-black/50 m-auto"
> >
{#if capture} {#if capture}
<div class="flex flex-col max-h-[90vh]"> <div class="flex flex-col max-h-[90vh]">
+13 -10
View File
@@ -3,6 +3,7 @@
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme"; import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
import { currentRoute } from "../stores/route"; import { currentRoute } from "../stores/route";
import { playgroundActivity } from "../stores/playgroundActivity"; import { playgroundActivity } from "../stores/playgroundActivity";
import { performanceEnabled } from "../stores/api";
import ConnectionStatus from "./ConnectionStatus.svelte"; import ConnectionStatus from "./ConnectionStatus.svelte";
function handleTitleChange(newTitle: string): void { function handleTitleChange(newTitle: string): void {
@@ -84,16 +85,18 @@
> >
Logs Logs
</a> </a>
<a {#if $performanceEnabled}
href="/performance" <a
use:link href="/performance"
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap" use:link
class:font-semibold={isActive("/performance", $currentRoute)} class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
class:underline={isActive("/performance", $currentRoute)} class:font-semibold={isActive("/performance", $currentRoute)}
class:underline-offset-4={isActive("/performance", $currentRoute)} class:underline={isActive("/performance", $currentRoute)}
> class:underline-offset-4={isActive("/performance", $currentRoute)}
Performance >
</a> Performance
</a>
{/if}
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})"> <button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
{#if $themeMode === "system"} {#if $themeMode === "system"}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
@@ -0,0 +1,85 @@
<script lang="ts">
import type { Snippet } from "svelte";
interface Props {
metadata: Record<string, string> | undefined;
children: Snippet;
}
let { metadata, children }: Props = $props();
let entries = $derived(Object.entries(metadata || {}));
let triggerEl: HTMLElement | undefined = $state();
let tooltipEl: HTMLDivElement | undefined = $state();
let show = $state(false);
let tooltipStyle = $state("");
function positionTooltip() {
if (!triggerEl || !tooltipEl) return;
const triggerRect = triggerEl.getBoundingClientRect();
const tooltipRect = tooltipEl.getBoundingClientRect();
const margin = 8;
const viewportWidth = window.innerWidth;
const viewportHeight = window.innerHeight;
let left = triggerRect.left;
let top = triggerRect.bottom + margin;
// Keep tooltip within horizontal viewport bounds
if (left + tooltipRect.width > viewportWidth - margin) {
left = triggerRect.right - tooltipRect.width;
}
if (left < margin) {
left = margin;
}
// Flip above trigger if it would overflow the bottom
if (top + tooltipRect.height > viewportHeight - margin) {
top = triggerRect.top - tooltipRect.height - margin;
}
tooltipStyle = `left: ${left}px; top: ${top}px; max-width: calc(100vw - ${margin * 2}px);`;
}
function onEnter() {
show = true;
requestAnimationFrame(positionTooltip);
}
function onLeave() {
show = false;
}
</script>
<span
bind:this={triggerEl}
onmouseenter={onEnter}
onmouseleave={onLeave}
onfocus={onEnter}
onblur={onLeave}
class="inline-flex"
role="button"
tabindex="0"
aria-label="Show metadata"
>
{@render children()}
</span>
{#if show && entries.length > 0}
<div
bind:this={tooltipEl}
style={tooltipStyle}
class="fixed px-3 py-2 bg-gray-900 text-white text-sm rounded-md z-50 normal-case min-w-[12rem] max-w-[24rem] shadow-lg whitespace-normal"
>
<table class="w-full text-left">
<tbody>
{#each entries as [key, value]}
<tr class="border-b border-white/10 last:border-0">
<td class="py-1 pr-3 font-medium whitespace-nowrap text-primary">{key}</td>
<td class="py-1 break-all">{value}</td>
</tr>
{/each}
</tbody>
</table>
</div>
{/if}
+4
View File
@@ -25,6 +25,8 @@ export interface Model {
export interface TokenMetrics { export interface TokenMetrics {
cache_tokens: number; cache_tokens: number;
draft_tokens: number;
draft_acc_tokens: number;
input_tokens: number; input_tokens: number;
output_tokens: number; output_tokens: number;
prompt_per_second: number; prompt_per_second: number;
@@ -41,6 +43,8 @@ export interface ActivityLogEntry {
tokens: TokenMetrics; tokens: TokenMetrics;
duration_ms: number; duration_ms: number;
has_capture: boolean; has_capture: boolean;
error_msg?: string;
metadata?: Record<string, string>;
} }
export interface ReqRespCapture { export interface ReqRespCapture {
+188 -120
View File
@@ -2,25 +2,13 @@
import { metrics, getCapture } from "../stores/api"; import { metrics, getCapture } from "../stores/api";
import ActivityStats from "../components/ActivityStats.svelte"; import ActivityStats from "../components/ActivityStats.svelte";
import Tooltip from "../components/Tooltip.svelte"; import Tooltip from "../components/Tooltip.svelte";
import MetadataTooltip from "../components/MetadataTooltip.svelte";
import CaptureDialog from "../components/CaptureDialog.svelte"; import CaptureDialog from "../components/CaptureDialog.svelte";
import { persistentStore } from "../stores/persistent"; import { persistentStore } from "../stores/persistent";
import { onMount } from "svelte"; import { onMount } from "svelte";
import type { ReqRespCapture } from "../lib/types"; import type { ReqRespCapture } from "../lib/types";
type ColumnKey = type ColumnKey = string;
| "id"
| "time"
| "model"
| "req_path"
| "resp_status_code"
| "resp_content_type"
| "cached"
| "prompt"
| "generated"
| "prompt_speed"
| "gen_speed"
| "duration"
| "capture";
interface ColumnDef { interface ColumnDef {
key: ColumnKey; key: ColumnKey;
@@ -33,26 +21,31 @@
{ key: "time", label: "Time", defaultVisible: true }, { key: "time", label: "Time", defaultVisible: true },
{ key: "model", label: "Model", defaultVisible: true }, { key: "model", label: "Model", defaultVisible: true },
{ key: "req_path", label: "Path", defaultVisible: false }, { key: "req_path", label: "Path", defaultVisible: false },
{ key: "resp_status_code", label: "Status", defaultVisible: false }, { key: "resp_status_code", label: "Status", defaultVisible: true },
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false }, { key: "resp_content_type", label: "Content-Type", defaultVisible: false },
{ key: "cached", label: "Cached", defaultVisible: true }, { key: "cached", label: "Cached", defaultVisible: true },
{ key: "prompt", label: "Prompt", defaultVisible: true }, { key: "prompt", label: "Prompt", defaultVisible: true },
{ key: "generated", label: "Generated", defaultVisible: true }, { key: "generated", label: "Generated", defaultVisible: true },
{ key: "drafted", label: "Drafted", defaultVisible: false },
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true }, { key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true }, { key: "gen_speed", label: "Gen Speed", defaultVisible: true },
{ key: "duration", label: "Duration", defaultVisible: true }, { key: "duration", label: "Duration", defaultVisible: true },
{ key: "capture", label: "Capture", defaultVisible: true }, { key: "capture", label: "Capture", defaultVisible: true },
{ key: "meta", label: "Meta", defaultVisible: false },
]; ];
const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key); const defaultVisibleKeys = columns.filter((c) => c.defaultVisible).map((c) => c.key);
const visibleColumns = persistentStore<ColumnKey[]>( const visibleColumns = persistentStore<ColumnKey[]>("activity-columns", defaultVisibleKeys);
"activity-columns", const columnOrder = persistentStore<ColumnKey[]>(
defaultVisibleKeys "activity-column-order",
columns.map((c) => c.key)
); );
let columnsMenuOpen = $state(false); let columnsMenuOpen = $state(false);
let dropdownContainer: HTMLDivElement | null = null; let dropdownContainer: HTMLDivElement | null = null;
let dragKey: ColumnKey | null = $state(null);
let dragOverKey: ColumnKey | null = $state(null);
onMount(() => { onMount(() => {
function handleKeydown(e: KeyboardEvent) { function handleKeydown(e: KeyboardEvent) {
@@ -84,10 +77,92 @@
} }
} }
function isColumnVisible(key: ColumnKey): boolean {
return $visibleColumns.includes(key);
}
function handleDragStart(e: DragEvent, key: ColumnKey) {
dragKey = key;
e.dataTransfer?.setData("text/plain", key);
if (e.dataTransfer) {
e.dataTransfer.effectAllowed = "move";
}
}
function handleDragOver(e: DragEvent, key: ColumnKey) {
e.preventDefault();
if (e.dataTransfer) {
e.dataTransfer.dropEffect = "move";
}
dragOverKey = key;
}
function handleDrop(e: DragEvent, targetKey: ColumnKey) {
e.preventDefault();
if (!dragKey || dragKey === targetKey) return;
const order = [...$columnOrder];
const fromIndex = order.indexOf(dragKey);
let toIndex = order.indexOf(targetKey);
if (fromIndex === -1 || toIndex === -1) return;
order.splice(fromIndex, 1);
if (fromIndex < toIndex) {
toIndex -= 1;
}
order.splice(toIndex, 0, dragKey);
columnOrder.set(order);
}
function handleDragEnd() {
dragKey = null;
dragOverKey = null;
}
let orderedColumns = $derived(
columns.slice().sort((a, b) => {
const aIndex = $columnOrder.indexOf(a.key);
const bIndex = $columnOrder.indexOf(b.key);
if (aIndex === -1 && bIndex === -1) return 0;
if (aIndex === -1) return 1;
if (bIndex === -1) return -1;
return aIndex - bIndex;
})
);
let activeVisibleColumns = $derived(
columns
.filter((c) => isColumnVisible(c.key))
.sort((a, b) => {
const aIndex = $columnOrder.indexOf(a.key);
const bIndex = $columnOrder.indexOf(b.key);
if (aIndex === -1 && bIndex === -1) return 0;
if (aIndex === -1) return 1;
if (bIndex === -1) return -1;
return aIndex - bIndex;
})
.map((c) => c.key)
);
let columnLabelMap = $derived(Object.fromEntries(columns.map((c) => [c.key, c.label])));
$effect(() => {
const staticKeys = new Set(columns.map((c) => c.key));
const order = $columnOrder;
const hasStale = order.some((k) => !staticKeys.has(k));
const missing = columns.filter((c) => !order.includes(c.key)).map((c) => c.key);
if (hasStale || missing.length > 0) {
const cleaned = order.filter((k) => staticKeys.has(k));
columnOrder.set([...cleaned, ...missing]);
}
});
function formatSpeed(speed: number): string { function formatSpeed(speed: number): string {
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s"; return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
} }
function formatDrafted(drafted: number, accepted: number): string {
return drafted > 0 ? (accepted * 100 / drafted).toFixed(1) + "% (" + accepted + "/" + drafted + ")" : "-";
}
function formatDuration(ms: number): string { function formatDuration(ms: number): string {
return (ms / 1000).toFixed(2) + "s"; return (ms / 1000).toFixed(2) + "s";
} }
@@ -157,22 +232,37 @@
</svg> </svg>
</button> </button>
{#if columnsMenuOpen} {#if columnsMenuOpen}
<div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]"> <div class="absolute right-0 top-full mt-1 bg-surface border border-gray-200 dark:border-white/10 rounded shadow-lg z-10 py-1 min-w-[16rem]" role="list">
<div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10"> <div class="px-3 py-2 text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400 border-b border-gray-200 dark:border-white/10" role="presentation">
Columns Columns
</div> </div>
{#each columns as col (col.key)} {#each orderedColumns as col (col.key)}
<label {@const key = col.key}
class="flex items-center gap-2 px-3 py-1.5 text-sm cursor-pointer hover:bg-secondary-hover transition-colors" <div
class="flex items-center gap-2 px-3 py-1.5 text-sm hover:bg-secondary-hover transition-colors {dragOverKey === key && dragKey !== key ? 'bg-primary/10 ring-1 ring-primary/40' : ''} {dragKey === key ? 'opacity-40' : ''}"
role="listitem"
ondragover={(e) => handleDragOver(e, key)}
ondrop={(e) => handleDrop(e, key)}
> >
<input <span
type="checkbox" class="text-txtsecondary select-none cursor-grab"
checked={$visibleColumns.includes(col.key)} draggable={true}
onchange={() => toggleColumn(col.key)} role="button"
class="rounded" tabindex="-1"
/> aria-label="Drag to reorder {col.label}"
{col.label} ondragstart={(e) => handleDragStart(e, key)}
</label> ondragend={handleDragEnd}
>⋮⋮</span>
<label class="flex items-center gap-2 flex-1 cursor-pointer">
<input
type="checkbox"
checked={isColumnVisible(key)}
onchange={() => toggleColumn(key)}
class="rounded"
/>
{col.label}
</label>
</div>
{/each} {/each}
</div> </div>
{/if} {/if}
@@ -182,112 +272,90 @@
<table class="min-w-full divide-y"> <table class="min-w-full divide-y">
<thead class="border-gray-200 dark:border-white/10"> <thead class="border-gray-200 dark:border-white/10">
<tr class="text-left text-xs uppercase tracking-wider"> <tr class="text-left text-xs uppercase tracking-wider">
{#if $visibleColumns.includes("id")} {#each activeVisibleColumns as key (key)}
<th class="px-6 py-3">ID</th>
{/if}
{#if $visibleColumns.includes("time")}
<th class="px-6 py-3">Time</th>
{/if}
{#if $visibleColumns.includes("model")}
<th class="px-6 py-3">Model</th>
{/if}
{#if $visibleColumns.includes("req_path")}
<th class="px-6 py-3">Path</th>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<th class="px-6 py-3">Status</th>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<th class="px-6 py-3">Content-Type</th>
{/if}
{#if $visibleColumns.includes("cached")}
<th class="px-6 py-3"> <th class="px-6 py-3">
Cached <Tooltip content="prompt tokens from cache" /> {#if key === "cached"}
Cached <Tooltip content="prompt tokens from cache" />
{:else if key === "prompt"}
Prompt <Tooltip content="new prompt tokens processed" />
{:else if key === "drafted"}
Drafted <Tooltip content="acceptance rate (accepted/drafted)" />
{:else}
{columnLabelMap[key] ?? key}
{/if}
</th> </th>
{/if} {/each}
{#if $visibleColumns.includes("prompt")}
<th class="px-6 py-3">
Prompt <Tooltip content="new prompt tokens processed" />
</th>
{/if}
{#if $visibleColumns.includes("generated")}
<th class="px-6 py-3">Generated</th>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<th class="px-6 py-3">Prompt Speed</th>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<th class="px-6 py-3">Gen Speed</th>
{/if}
{#if $visibleColumns.includes("duration")}
<th class="px-6 py-3">Duration</th>
{/if}
{#if $visibleColumns.includes("capture")}
<th class="px-6 py-3">Capture</th>
{/if}
</tr> </tr>
</thead> </thead>
<tbody class="divide-y"> <tbody class="divide-y">
{#if sortedMetrics.length === 0} {#if sortedMetrics.length === 0}
<tr> <tr>
<td colspan={$visibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400"> <td colspan={activeVisibleColumns.length} class="px-6 py-8 text-center text-sm text-gray-500 dark:text-gray-400">
No activity recorded No activity recorded
</td> </td>
</tr> </tr>
{:else} {:else}
{#each sortedMetrics as metric (metric.id)} {#each sortedMetrics as metric (metric.id)}
<tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10"> <tr class="whitespace-nowrap text-sm border-gray-200 dark:border-white/10">
{#if $visibleColumns.includes("id")} {#each activeVisibleColumns as key (key)}
<td class="px-4 py-4">{metric.id + 1}</td>
{/if}
{#if $visibleColumns.includes("time")}
<td class="px-6 py-4">{formatRelativeTime(metric.timestamp)}</td>
{/if}
{#if $visibleColumns.includes("model")}
<td class="px-6 py-4">{metric.model}</td>
{/if}
{#if $visibleColumns.includes("req_path")}
<td class="px-6 py-4">{metric.req_path || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_status_code")}
<td class="px-6 py-4">{metric.resp_status_code || "-"}</td>
{/if}
{#if $visibleColumns.includes("resp_content_type")}
<td class="px-6 py-4">{metric.resp_content_type || "-"}</td>
{/if}
{#if $visibleColumns.includes("cached")}
<td class="px-6 py-4">{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}</td>
{/if}
{#if $visibleColumns.includes("prompt")}
<td class="px-6 py-4">{metric.tokens.input_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("generated")}
<td class="px-6 py-4">{metric.tokens.output_tokens.toLocaleString()}</td>
{/if}
{#if $visibleColumns.includes("prompt_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.prompt_per_second)}</td>
{/if}
{#if $visibleColumns.includes("gen_speed")}
<td class="px-6 py-4">{formatSpeed(metric.tokens.tokens_per_second)}</td>
{/if}
{#if $visibleColumns.includes("duration")}
<td class="px-6 py-4">{formatDuration(metric.duration_ms)}</td>
{/if}
{#if $visibleColumns.includes("capture")}
<td class="px-6 py-4"> <td class="px-6 py-4">
{#if metric.has_capture} {#if key === "id"}
<button {metric.id + 1}
onclick={() => viewCapture(metric.id)} {:else if key === "time"}
disabled={loadingCaptureId === metric.id} {formatRelativeTime(metric.timestamp)}
class="btn btn--sm" {:else if key === "model"}
> {metric.model}
{loadingCaptureId === metric.id ? "..." : "View"} {:else if key === "req_path"}
</button> {metric.req_path || "-"}
{:else if key === "resp_status_code"}
{#if metric.error_msg}
<span class="text-red-500 dark:text-red-400 cursor-help" title={metric.error_msg}>
{metric.resp_status_code || "-"}
</span>
{:else}
{metric.resp_status_code || "-"}
{/if}
{:else if key === "resp_content_type"}
{metric.resp_content_type || "-"}
{:else if key === "cached"}
{metric.tokens.cache_tokens > 0 ? metric.tokens.cache_tokens.toLocaleString() : "-"}
{:else if key === "prompt"}
{metric.tokens.input_tokens.toLocaleString()}
{:else if key === "generated"}
{metric.tokens.output_tokens.toLocaleString()}
{:else if key === "drafted"}
{formatDrafted(metric.tokens.draft_tokens, metric.tokens.draft_acc_tokens)}
{:else if key === "prompt_speed"}
{formatSpeed(metric.tokens.prompt_per_second)}
{:else if key === "gen_speed"}
{formatSpeed(metric.tokens.tokens_per_second)}
{:else if key === "duration"}
{formatDuration(metric.duration_ms)}
{:else if key === "capture"}
{#if metric.has_capture}
<button
onclick={() => viewCapture(metric.id)}
disabled={loadingCaptureId === metric.id}
class="btn btn--sm"
>
{loadingCaptureId === metric.id ? "..." : "View"}
</button>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
{:else if key === "meta"}
{#if Object.keys(metric.metadata || {}).length > 0}
<MetadataTooltip metadata={metric.metadata}>
<span class="cursor-help text-txtsecondary hover:text-txtmain">...</span>
</MetadataTooltip>
{:else}
<span class="text-txtsecondary">-</span>
{/if}
{:else} {:else}
<span class="text-txtsecondary">-</span> -
{/if} {/if}
</td> </td>
{/if} {/each}
</tr> </tr>
{/each} {/each}
{/if} {/if}
+15
View File
@@ -19,6 +19,7 @@ export const proxyLogs = writable<string>("");
export const upstreamLogs = writable<string>(""); export const upstreamLogs = writable<string>("");
export const metrics = writable<ActivityLogEntry[]>([]); export const metrics = writable<ActivityLogEntry[]>([]);
export const inFlightRequests = writable<number>(0); export const inFlightRequests = writable<number>(0);
export const performanceEnabled = writable<boolean>(false);
export const versionInfo = writable<VersionInfo>({ export const versionInfo = writable<VersionInfo>({
build_date: "unknown", build_date: "unknown",
commit: "unknown", commit: "unknown",
@@ -210,6 +211,20 @@ export async function getCapture(id: number): Promise<ReqRespCapture | null> {
} }
} }
export async function checkPerformanceEnabled(): Promise<void> {
try {
const response = await fetch("/api/performance");
if (!response.ok) {
performanceEnabled.set(false);
return;
}
const data = await response.json();
performanceEnabled.set(data.enabled);
} catch {
performanceEnabled.set(false);
}
}
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> { export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
try { try {
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance"; const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";