Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0292c90ca1 | |||
| 617c7dc6b9 | |||
| 542b79dacf | |||
| 0a25b3bd31 | |||
| 32bc781326 | |||
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b |
@@ -0,0 +1,76 @@
|
||||
name: Build CUDA image (fork)
|
||||
|
||||
# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
|
||||
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
|
||||
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_swap_version:
|
||||
description: "llama-swap version label (image tag prefix)"
|
||||
required: false
|
||||
default: "v230"
|
||||
llamacpp_build:
|
||||
description: "llama.cpp CUDA server build (base image tag suffix)"
|
||||
required: false
|
||||
default: "b9821"
|
||||
# Building the build definition itself kicks off a fresh image.
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- ".gitea/workflows/build-cuda-image.yml"
|
||||
- "docker/fork-cuda.Containerfile"
|
||||
|
||||
env:
|
||||
REGISTRY: gitea.stevedudenhoeffer.com
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Compute image metadata
|
||||
id: meta
|
||||
run: |
|
||||
LS_VER="${{ inputs.llama_swap_version || 'v230' }}"
|
||||
LCPP="${{ inputs.llamacpp_build || 'b9821' }}"
|
||||
{
|
||||
echo "image=${REGISTRY}/${{ github.repository }}"
|
||||
echo "tag=${LS_VER}-cuda-${LCPP}"
|
||||
echo "base_tag=server-cuda-${LCPP}"
|
||||
echo "ls_version=${LS_VER}"
|
||||
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Gitea registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ secrets.REGISTRY_USER }}
|
||||
password: ${{ secrets.REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/fork-cuda.Containerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
BASE_TAG=${{ steps.meta.outputs.base_tag }}
|
||||
LS_VERSION=${{ steps.meta.outputs.ls_version }}
|
||||
GIT_HASH=${{ github.sha }}
|
||||
BUILD_DATE=${{ steps.meta.outputs.build_date }}
|
||||
tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
|
||||
|
||||
- name: Summary
|
||||
run: |
|
||||
echo "Pushed ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}" >> "$GITHUB_STEP_SUMMARY"
|
||||
@@ -5,16 +5,14 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
- typescript, vite and svelte 5 for UI (located in ui-svelte/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- keep them short and focused on changes
|
||||
- skip the test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
@@ -30,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
internal/server: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
|
||||
+21
-2
@@ -572,6 +572,24 @@
|
||||
"default": {},
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
},
|
||||
"upstream": {
|
||||
"type": "object",
|
||||
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||
"properties": {
|
||||
"ignorePaths": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [
|
||||
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||
],
|
||||
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"default": {}
|
||||
},
|
||||
"routing": {
|
||||
"type": "object",
|
||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||
@@ -583,10 +601,11 @@
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"serial",
|
||||
"fifo"
|
||||
],
|
||||
"default": "fifo",
|
||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
||||
"default": "serial",
|
||||
"description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
|
||||
+25
-3
@@ -134,6 +134,18 @@ apiKeys:
|
||||
- "${env.API_KEY_1}"
|
||||
- "${env.API_KEY_2}"
|
||||
|
||||
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||
# - optional, default: empty dictionary
|
||||
# - recommended to only use in special use cases. Leaving it as the
|
||||
# default will typically be the best experience
|
||||
upstream:
|
||||
# ignorePaths: list of RE2 compatible regular expressions
|
||||
# - default: (see below)
|
||||
# - any request to a path matching any of the regular expressions
|
||||
# will be ignored and not trigger a swap
|
||||
ignorePaths:
|
||||
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||
|
||||
# models: a dictionary of model configurations
|
||||
# - required
|
||||
# - each key is the model's ID, used in API requests
|
||||
@@ -544,11 +556,21 @@ routing:
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# scheduler: how queued requests are ordered.
|
||||
# The default and only valid scheduler is "fifo"
|
||||
# scheduler: how queued requests are ordered and run.
|
||||
# - optional, default on this fork: "serial"
|
||||
# - valid values:
|
||||
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
|
||||
# order; only one request runs at a time; switching to a different model
|
||||
# evicts every other running model first so a single model occupies memory
|
||||
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
|
||||
# settings below (priority) do not apply.
|
||||
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
|
||||
# swaps and a model serves up to its concurrencyLimit in parallel; models
|
||||
# in non-exclusive groups can run concurrently. Requests may be reordered.
|
||||
scheduler:
|
||||
use: fifo
|
||||
use: serial
|
||||
settings:
|
||||
# fifo settings only apply when use: fifo
|
||||
fifo:
|
||||
# priority: a dictionary of model ID -> priority
|
||||
# - optional, default: empty dictionary
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# Build a CUDA llama-swap image FROM THIS FORK's source (includes the serial
|
||||
# scheduler) and layer it on a pinned llama.cpp CUDA server base. Produces e.g.:
|
||||
# gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# BASE_TAG selects the llama.cpp CUDA runtime + llama-server build, e.g.
|
||||
# "server-cuda-b9821". The llama-swap binary (with the embedded Svelte UI) is
|
||||
# compiled from the repo at build time, so no GitHub release is required.
|
||||
#
|
||||
# Build context is the repo root:
|
||||
# docker build -f docker/fork-cuda.Containerfile \
|
||||
# --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .
|
||||
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda-b9821
|
||||
|
||||
# ---- Stage 1: build the Svelte UI (embedded into the binary) ----
|
||||
FROM node:22-bookworm-slim AS ui
|
||||
WORKDIR /src/ui-svelte
|
||||
# Install deps first for layer caching. .npmrc carries legacy-peer-deps=true,
|
||||
# which the project relies on (tailwind/vite peer ranges), so copy it before
|
||||
# npm ci or the strict resolver fails with ERESOLVE.
|
||||
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
|
||||
RUN npm ci
|
||||
COPY ui-svelte/ ./
|
||||
# `npm run build` is `vite build --emptyOutDir`; vite.config.ts writes to
|
||||
# ../internal/server/ui_dist, which //go:embed picks up in the next stage.
|
||||
RUN mkdir -p /src/internal/server && npm run build
|
||||
|
||||
# ---- Stage 2: build the llama-swap binary with the embedded UI ----
|
||||
FROM golang:1.26-bookworm AS build
|
||||
WORKDIR /src
|
||||
# Cache modules independently of source churn.
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Overlay the freshly built UI so //go:embed ui_dist ships the real assets
|
||||
# instead of the committed placeholder.
|
||||
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
|
||||
ARG LS_VERSION=v230
|
||||
ARG GIT_HASH=unknown
|
||||
ARG BUILD_DATE=unknown
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
|
||||
-o /out/llama-swap .
|
||||
|
||||
# ---- Stage 3: runtime image on the pinned llama.cpp CUDA base ----
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# Run as root by default to match the upstream `vNNN-cuda-bNNNN` (non-suffixed)
|
||||
# image that ragnaros pulls today: it needs root to reach the mounted docker
|
||||
# socket for container-backed models (sd-server). Override UID/GID at build time
|
||||
# for a non-root variant.
|
||||
ARG UID=0
|
||||
ARG GID=0
|
||||
ARG USER_HOME=/root
|
||||
ENV HOME=$USER_HOME
|
||||
|
||||
RUN set -eux; \
|
||||
if [ "$UID" -ne 0 ]; then \
|
||||
if [ "$GID" -ne 0 ]; then groupadd --system --gid "$GID" app; fi; \
|
||||
useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
|
||||
fi; \
|
||||
mkdir --parents "$HOME" /app; \
|
||||
chown --recursive "$UID:$GID" "$HOME" /app
|
||||
|
||||
COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
|
||||
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml
|
||||
|
||||
USER $UID:$GID
|
||||
WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
)
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
+3
-661
@@ -2,16 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
@@ -163,6 +150,9 @@ type Config struct {
|
||||
|
||||
// support remote peers, see issue #433, #296
|
||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||
|
||||
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||
Upstream UpstreamConfig `yaml:"upstream"`
|
||||
}
|
||||
|
||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||
@@ -221,424 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "fifo"
|
||||
}
|
||||
if config.Routing.Scheduler.Use != "fifo" {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
@@ -683,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -266,6 +266,9 @@ groups:
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
@@ -274,7 +277,7 @@ groups:
|
||||
},
|
||||
},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "space in second key reports correct index",
|
||||
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
@@ -1567,7 +1572,7 @@ groups:
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
// default group injected for orphaned models (none here) still leaves g1
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||
@@ -1626,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||
|
||||
@@ -255,6 +255,9 @@ groups:
|
||||
"mthree": "model3",
|
||||
},
|
||||
Groups: expectedGroups,
|
||||
Upstream: UpstreamConfig{
|
||||
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||
},
|
||||
Routing: RoutingConfig{
|
||||
Router: RouterConfig{
|
||||
Use: "group",
|
||||
@@ -263,7 +266,7 @@ groups:
|
||||
},
|
||||
},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,441 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
// Apply default for upstream.ignorePaths when not specified. The default
|
||||
// matches common static-asset suffixes so they do not trigger a swap.
|
||||
if len(config.Upstream.IgnorePaths) == 0 {
|
||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
// This fork defaults to the "serial" scheduler: one model loaded at a time,
|
||||
// requests served in strict arrival order. Set use: fifo for the upstream
|
||||
// throughput-oriented behavior that batches same-model requests.
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "serial"
|
||||
}
|
||||
switch config.Routing.Scheduler.Use {
|
||||
case "fifo", "serial":
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// identityMapPaths is the set of dotted paths whose direct children are
|
||||
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||
// duplicate means the user has split one entity across files by mistake.
|
||||
var identityMapPaths = map[string]bool{
|
||||
"models": true,
|
||||
"groups": true,
|
||||
"profiles": true,
|
||||
"peers": true,
|
||||
"matrix": true,
|
||||
"routing.router.settings.groups": true,
|
||||
"routing.router.settings.matrix": true,
|
||||
}
|
||||
|
||||
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||
// and -config-dir (optional). At least one must be provided. The -config file
|
||||
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||
// merged in sorted filename order. The merged document is passed through the
|
||||
// existing LoadConfigFromReader pipeline unchanged.
|
||||
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||
if configPath == "" && configDir == "" {
|
||||
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||
}
|
||||
|
||||
var sourcePaths []string
|
||||
|
||||
if configPath != "" {
|
||||
sourcePaths = append(sourcePaths, configPath)
|
||||
}
|
||||
|
||||
if configDir != "" {
|
||||
dirFiles, err := listYAMLFiles(configDir)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
absConfig, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||
}
|
||||
for _, f := range dirFiles {
|
||||
absF, err := filepath.Abs(f)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||
}
|
||||
if absConfig == absF {
|
||||
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourcePaths = append(sourcePaths, dirFiles...)
|
||||
}
|
||||
|
||||
if len(sourcePaths) == 0 {
|
||||
return Config{}, fmt.Errorf("no configuration sources found")
|
||||
}
|
||||
|
||||
var merged *yaml.Node
|
||||
for _, p := range sourcePaths {
|
||||
node, err := parseSource(p)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if node == nil {
|
||||
continue // empty file
|
||||
}
|
||||
if merged == nil {
|
||||
merged = node
|
||||
continue
|
||||
}
|
||||
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if merged == nil {
|
||||
// All sources were empty; run the pipeline on empty input so defaults
|
||||
// and validation still apply (e.g. startPort, performance defaults).
|
||||
return LoadConfigFromReader(strings.NewReader(""))
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(merged)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||
}
|
||||
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||
}
|
||||
|
||||
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||
func listYAMLFiles(dir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
files = append(files, filepath.Join(dir, name))
|
||||
}
|
||||
sort.Strings(files)
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||
// Returns a nil node (no error) when the file is empty or contains only
|
||||
// comments.
|
||||
//
|
||||
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||
func parseSource(path string) (*yaml.Node, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||
}
|
||||
yamlStr, err := substituteEnvMacros(string(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||
}
|
||||
var doc yaml.Node
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||
}
|
||||
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||
root := &doc
|
||||
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||
root = root.Content[0]
|
||||
}
|
||||
if root.Kind == 0 || root.Content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if root.Kind != yaml.MappingNode {
|
||||
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||
// only one side are kept; shared keys are merged recursively under the rules
|
||||
// in mergeValue. srcPath is included in error messages to identify the file
|
||||
// that introduced the conflict.
|
||||
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||
srcIdx := indexMapping(src)
|
||||
|
||||
// First pass: merge shared keys in place.
|
||||
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||
keyNode := dst.Content[i]
|
||||
dstVal := dst.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
srcVal, ok := srcIdx[key]
|
||||
if !ok {
|
||||
continue // dst-only key, keep as-is
|
||||
}
|
||||
|
||||
childPath := joinPath(path, key)
|
||||
|
||||
if identityMapPaths[childPath] {
|
||||
// Identity-keyed map: each child key names a discrete entity
|
||||
// (a model, group, peer, ...). A shared child key is a hard
|
||||
// error; src-only children are appended in the second pass.
|
||||
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: append src-only keys.
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
if _, ok := dstIdx[key]; ok {
|
||||
continue // already merged above
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||
// entity and produces an error naming the conflicting key and source file.
|
||||
// src-only keys are appended to dst.
|
||||
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||
}
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
if _, dup := dstIdx[key]; dup {
|
||||
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||
// non-null side.
|
||||
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||
switch {
|
||||
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||
|
||||
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||
return nil
|
||||
|
||||
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||
if isNullScalar(dstVal) {
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
}
|
||||
if isNullScalar(srcVal) {
|
||||
return nil
|
||||
}
|
||||
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||
|
||||
case isNull(dstVal):
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
|
||||
case isNull(srcVal):
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||
}
|
||||
}
|
||||
|
||||
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||
func isNull(n *yaml.Node) bool {
|
||||
if n == nil || n.Kind == 0 {
|
||||
return true
|
||||
}
|
||||
return isNullScalar(n)
|
||||
}
|
||||
|
||||
func isNullScalar(n *yaml.Node) bool {
|
||||
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||
}
|
||||
|
||||
// indexMapping builds a key -> value-node index for a mapping node.
|
||||
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||
idx[n.Content[i].Value] = n.Content[i+1]
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func joinPath(parent, key string) string {
|
||||
if parent == "" {
|
||||
return key
|
||||
}
|
||||
return parent + "." + key
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||
// path of the written file.
|
||||
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||
return p
|
||||
}
|
||||
|
||||
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||
// ${PORT} allocation.
|
||||
func modelCfg(id, cmd string) string {
|
||||
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||
_, err := LoadConfigSources("", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||
models:
|
||||
`+modelCfg("model1", "echo hi")+`
|
||||
groups:
|
||||
group1:
|
||||
members: ["model1"]
|
||||
`)
|
||||
cfg, err := LoadConfigSources(cfgPath, "")
|
||||
require.NoError(t, err)
|
||||
_, id, ok := cfg.FindConfig("model1")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "model1", id)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"alpha", "beta"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present", want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||
// -config lives outside -config-dir; both contribute models additively.
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
cfgDir := t.TempDir()
|
||||
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"base", "ext"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present after merge", want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||
// is also a member of -config-dir is rejected.
|
||||
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
|
||||
_, err := LoadConfigSources(cfgPath, dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// Both macros are available globally after merge.
|
||||
low, ok := cfg.Macros.Get("LOW")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 1, low)
|
||||
high, ok := cfg.Macros.Get("HIGH")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 2, high)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupA:
|
||||
members: [m1]
|
||||
`)
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupB:
|
||||
members: [m2]
|
||||
`)
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
groups := cfg.Routing.Router.Settings.Groups
|
||||
assert.Contains(t, groups, "groupA")
|
||||
assert.Contains(t, groups, "groupB")
|
||||
// default group added by pipeline for orphaned/leftover routing groups...
|
||||
// here both groups reference distinct models
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||
// verifies env/macro substitution runs on the merged document.
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["m1"]
|
||||
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||
// parseSource unmarshalled before env substitution, so the brace in
|
||||
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||
// Two files defining distinct models, scanned in z..a order by filename.
|
||||
// Determine merged result is the same regardless of how the FS returns them.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||
|
||||
const runs = 3
|
||||
for i := 0; i < runs; i++ {
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// startPort-based allocation: first allocated model gets 5800.
|
||||
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||
_, _, ok := cfg.FindConfig("amodel")
|
||||
assert.True(t, ok)
|
||||
_, _, ok = cfg.FindConfig("zmodel")
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgDir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Models, "m1")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||
// An empty -config-dir with no -config is an error: there is nothing to
|
||||
// load and silently producing an empty config would mask the misconfig.
|
||||
cfgDir := t.TempDir()
|
||||
_, err := LoadConfigSources("", cfgDir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||
// another — they do, because merge concats global macros before validation
|
||||
// runs. This test documents that a macro from file A is usable in file B.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["user"]
|
||||
assert.Contains(t, m.Cmd, "hello")
|
||||
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||
// File A: routing.router block absent (null on root for routing);
|
||||
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// DefaultUpstreamIgnorePathsPattern is the default regular expression applied
|
||||
// to upstream.ignorePaths when the section is empty or absent from the config.
|
||||
// It matches common static-asset suffixes so requests for .js/.css/.png/etc.
|
||||
// files do not trigger a model swap.
|
||||
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`
|
||||
|
||||
// DefaultUpstreamIgnorePaths returns the default compiled ignore paths used
|
||||
// when upstream.ignorePaths is not specified in the config. The returned slice
|
||||
// is fresh so callers may mutate it without affecting other configs.
|
||||
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
|
||||
return []*regexp.Regexp{regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)}
|
||||
}
|
||||
|
||||
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
|
||||
type UpstreamConfig struct {
|
||||
// IgnorePaths is a slice of compiled regular expressions. Any request to
|
||||
// /upstream/<model>/<path> whose remaining path matches any of these
|
||||
// expressions will be ignored and not trigger a swap. When the config
|
||||
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
|
||||
IgnorePaths []*regexp.Regexp `yaml:"-"`
|
||||
}
|
||||
|
||||
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
|
||||
// plain strings, which are then compiled into *regexp.Regexp.
|
||||
type rawUpstreamConfig struct {
|
||||
IgnorePaths []string `yaml:"ignorePaths"`
|
||||
}
|
||||
|
||||
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||
// entry fails to compile, an error is returned.
|
||||
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
var raw rawUpstreamConfig
|
||||
if err := value.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||
for _, p := range raw.IgnorePaths {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||
}
|
||||
patterns = append(patterns, re)
|
||||
}
|
||||
u.IgnorePaths = patterns
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const upstreamConfigHeader = `
|
||||
models:
|
||||
model1:
|
||||
cmd: path/to/cmd --arg1 one
|
||||
proxy: "http://localhost:8080"
|
||||
`
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||
// When upstream is not specified at all, the default pattern is applied.
|
||||
content := upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
|
||||
def := cfg.Upstream.IgnorePaths[0]
|
||||
assert.IsType(t, ®exp.Regexp{}, def)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||
|
||||
// The default matches common static-asset suffixes.
|
||||
assert.True(t, def.MatchString("/foo.js"))
|
||||
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||
assert.True(t, def.MatchString("/static/img.png"))
|
||||
assert.True(t, def.MatchString("/notes.txt"))
|
||||
assert.True(t, def.MatchString("/favicon.ico"))
|
||||
// And does not match inference API paths.
|
||||
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||
assert.False(t, def.MatchString("/v1/models"))
|
||||
assert.False(t, def.MatchString("/health"))
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||
// When upstream is present but ignorePaths is omitted, the default is still
|
||||
// applied.
|
||||
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||
- "^/static/.*"
|
||||
` + upstreamConfigHeader
|
||||
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||
|
||||
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||
|
||||
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||
for _, re := range cfg.Upstream.IgnorePaths {
|
||||
assert.IsType(t, ®exp.Regexp{}, re)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||
content := `
|
||||
upstream:
|
||||
ignorePaths:
|
||||
- "[invalid("
|
||||
` + upstreamConfigHeader
|
||||
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||
}
|
||||
@@ -92,9 +92,14 @@ type Effects interface {
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
||||
// from conf and bound to the given planner and effects. Currently only "fifo"
|
||||
// (the default) is supported.
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
|
||||
// conf and bound to the given planner and effects. Supported values are "fifo"
|
||||
// (throughput-oriented, batches same-model requests) and "serial" (strict
|
||||
// one-model-at-a-time, exact arrival order).
|
||||
//
|
||||
// The deployment default is applied by config loading (LoadConfig sets Use to
|
||||
// "serial" when unset). The "" fallback here is the library default and remains
|
||||
// "fifo" so callers that build a Config directly keep the original behavior.
|
||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||
use := conf.Routing.Scheduler.Use
|
||||
if use == "" {
|
||||
@@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe
|
||||
switch use {
|
||||
case "fifo":
|
||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||
case "serial":
|
||||
// Serial ignores the group planner: it always evicts every other model.
|
||||
return NewSerial(name, logger, eff), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
|
||||
// or batches: requests run in exact arrival order and at most one request runs at
|
||||
// any instant. When the next request targets a model other than the one loaded,
|
||||
// every other running model is evicted and the target is loaded before it runs,
|
||||
// so a single model occupies memory at a time — at the cost of throughput.
|
||||
//
|
||||
// Example: A B C A is served as A B C A. The final A reloads its model even
|
||||
// though it ran first, because B and C displaced it in between. (FIFO, by
|
||||
// contrast, would batch the two A requests: A A B C.)
|
||||
//
|
||||
// Serial ignores group/eviction policy entirely: it always evicts every other
|
||||
// running model, regardless of how groups are configured. That is what makes the
|
||||
// single-model guarantee a property of the scheduler rather than of the config.
|
||||
//
|
||||
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
|
||||
// internal locking is needed.
|
||||
type Serial struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
effects Effects
|
||||
|
||||
// queued holds requests in strict arrival order. It is never reordered.
|
||||
queued []HandlerReq
|
||||
|
||||
// active is the one request currently being processed (loading or serving),
|
||||
// or nil when idle. phase is meaningful only while active != nil.
|
||||
active *HandlerReq
|
||||
phase serialPhase
|
||||
}
|
||||
|
||||
// serialPhase is the lifecycle stage of the active request.
|
||||
type serialPhase int
|
||||
|
||||
const (
|
||||
phaseIdle serialPhase = iota
|
||||
phaseSwapping // waiting for OnSwapDone for active.Model
|
||||
phaseServing // waiting for OnServeDone for active.Model
|
||||
)
|
||||
|
||||
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
|
||||
// "stop every other running model", so the group planner is not consulted.
|
||||
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
|
||||
return &Serial{
|
||||
name: name,
|
||||
logger: logger,
|
||||
effects: eff,
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest validates the model and appends the request to the tail of the queue,
|
||||
// then tries to start the next job. Unknown models fail immediately.
|
||||
func (s *Serial) OnRequest(req HandlerReq) {
|
||||
if _, ok := s.effects.ModelState(req.Model); !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
s.queued = append(s.queued, req)
|
||||
broadcastQueuePositions(s.queued)
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// startNext begins processing the head of the queue when nothing is active. It
|
||||
// fast-paths a request whose model is already the sole loaded-and-ready process;
|
||||
// otherwise it launches a swap that evicts every other running model first. The
|
||||
// loop skips over requests for models that vanished (e.g. a config reload) and
|
||||
// requests whose caller disconnected before they could be served.
|
||||
func (s *Serial) startNext() {
|
||||
if s.active != nil {
|
||||
return // a job is already loading or serving
|
||||
}
|
||||
for len(s.queued) > 0 {
|
||||
req := s.queued[0]
|
||||
s.queued = s.queued[1:]
|
||||
broadcastQueuePositions(s.queued)
|
||||
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
|
||||
r := req
|
||||
s.active = &r
|
||||
|
||||
evict := s.otherRunning(req.Model)
|
||||
if state == process.StateReady && len(evict) == 0 {
|
||||
// Already loaded and the only model running — serve immediately.
|
||||
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
|
||||
if s.serve() {
|
||||
return
|
||||
}
|
||||
continue // caller gone; pick the next request
|
||||
}
|
||||
|
||||
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.phase = phaseSwapping
|
||||
s.effects.StartSwap(req.Model, evict)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// serve hands the active request its tracked handler. It returns true when the
|
||||
// request is now serving (await OnServeDone); false when the caller had already
|
||||
// disconnected, in which case active is cleared so the next job can start.
|
||||
func (s *Serial) serve() bool {
|
||||
if s.effects.GrantServe(*s.active, s.active.Model) {
|
||||
s.phase = phaseServing
|
||||
return true
|
||||
}
|
||||
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
return false
|
||||
}
|
||||
|
||||
// OnSwapDone fires when the load for the active request completes. On success the
|
||||
// request is served; on failure its caller receives the error and the queue
|
||||
// advances. A SwapDone that does not match the active load (e.g. its request was
|
||||
// unloaded or cancelled mid-load) is ignored.
|
||||
func (s *Serial) OnSwapDone(ev SwapDone) {
|
||||
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
|
||||
return
|
||||
}
|
||||
if ev.Err != nil {
|
||||
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
|
||||
s.effects.GrantError(*s.active, ev.Err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
return
|
||||
}
|
||||
if !s.serve() {
|
||||
s.startNext() // caller vanished while the model loaded; move on
|
||||
}
|
||||
}
|
||||
|
||||
// OnServeDone fires when the active request's handler returns. The slot is freed
|
||||
// and the next queued request begins.
|
||||
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
|
||||
if s.active == nil || s.phase != phaseServing {
|
||||
return
|
||||
}
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// OnCancel removes a disconnected client's request from the queue. A request that
|
||||
// is already active is left to finish: if it was loading, OnSwapDone's serve()
|
||||
// will find the caller gone (GrantServe false) and advance; if it was serving,
|
||||
// its handler returns normally and reaches OnServeDone.
|
||||
func (s *Serial) OnCancel(req HandlerReq) {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
kept := s.queued[:0]
|
||||
removed := false
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles state for an unload, stops the targeted processes, and
|
||||
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
|
||||
// models are failed; an active *loading* request for an unloaded model is failed
|
||||
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
|
||||
// active *serving* request is left for its handler to end when StopProcesses
|
||||
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
|
||||
// the processes being stopped on return.
|
||||
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
|
||||
s.effects.GrantError(*s.active, unloadErr)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if targetSet[q.Model] {
|
||||
s.effects.GrantError(q, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// A still-serving active request advances via OnServeDone when its killed
|
||||
// handler returns; only start the next job when nothing is active now.
|
||||
if s.active == nil {
|
||||
s.startNext()
|
||||
}
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every request the scheduler still holds: an active
|
||||
// loading request and all queued requests. A serving request is torn down with
|
||||
// its process by the baseRouter.
|
||||
func (s *Serial) OnShutdown(err error) {
|
||||
if s.active != nil && s.phase == phaseSwapping {
|
||||
s.effects.GrantError(*s.active, err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
for _, q := range s.queued {
|
||||
s.effects.GrantError(q, err)
|
||||
}
|
||||
s.queued = nil
|
||||
}
|
||||
|
||||
// otherRunning returns every running model except target, sorted for
|
||||
// deterministic eviction.
|
||||
func (s *Serial) otherRunning(target string) []string {
|
||||
var out []string
|
||||
for id := range s.effects.RunningModels() {
|
||||
if id != target {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously, reusing fakeEffects and the
|
||||
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
|
||||
// served request finishes via OnServeDone — the events the run loop delivers.
|
||||
|
||||
func newSerial(eff Effects) *Serial {
|
||||
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
|
||||
}
|
||||
|
||||
// lastStart returns the most recent StartSwap record.
|
||||
func lastStart(t *testing.T, eff *fakeEffects) startRec {
|
||||
t.Helper()
|
||||
if len(eff.starts) == 0 {
|
||||
t.Fatal("no StartSwap recorded")
|
||||
}
|
||||
return eff.starts[len(eff.starts)-1]
|
||||
}
|
||||
|
||||
func sameSet(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
m := map[string]int{}
|
||||
for _, x := range a {
|
||||
m[x]++
|
||||
}
|
||||
for _, x := range b {
|
||||
m[x]--
|
||||
}
|
||||
for _, v := range m {
|
||||
if v != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// servedOrder returns the model IDs of every successful serve grant in order.
|
||||
func servedOrder(eff *fakeEffects) []string {
|
||||
var out []string
|
||||
for _, g := range eff.grants {
|
||||
if g.err == nil && g.serve {
|
||||
out = append(out, g.model)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before load completes", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after load", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_UnknownModel(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => unknown
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if len(eff.starts) != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["x"] = process.StateReady // already running
|
||||
eff.states["y"] = process.StateReady // also running (e.g. left over)
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
st := lastStart(t, eff)
|
||||
if st.model != "a" {
|
||||
t.Fatalf("loading %s want a", st.model)
|
||||
}
|
||||
if !sameSet(st.evict, []string{"x", "y"}) {
|
||||
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OneJobAtATime verifies a second request waits while the first is
|
||||
// serving, and only starts after the first finishes.
|
||||
func TestSerial_OneJobAtATime(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately
|
||||
s.OnRequest(req("b")) // must wait — a is serving
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// a finishes -> b may now load (evicting a).
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
|
||||
}
|
||||
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
|
||||
t.Errorf("b evict=%v want [a]", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
|
||||
// already-loaded model run without a reload, one after another.
|
||||
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // cold load
|
||||
s.OnRequest(req("a")) // queued behind the first
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2", got)
|
||||
}
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
|
||||
// qwen36 execute in EXACTLY that order with evictions between each model switch,
|
||||
// including reloading qwen36 at the end even though it ran first.
|
||||
func TestSerial_StrictArrivalOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
s := newSerial(eff)
|
||||
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
// Only the first job starts loading; the rest wait their turn.
|
||||
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
|
||||
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
|
||||
}
|
||||
|
||||
// step completes the current model's load+serve and returns control to the
|
||||
// scheduler, which must start the next queued model.
|
||||
step := func(model string, wantEvict []string) {
|
||||
t.Helper()
|
||||
st := lastStart(t, eff)
|
||||
if st.model != model {
|
||||
t.Fatalf("loading %q want %q", st.model, model)
|
||||
}
|
||||
if !sameSet(st.evict, wantEvict) {
|
||||
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
|
||||
}
|
||||
// Simulate the eviction + load actually happening.
|
||||
for _, e := range st.evict {
|
||||
eff.states[e] = process.StateStopped
|
||||
}
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
|
||||
step("qwen36", nil) // cold load, nothing else running
|
||||
step("qwen35", []string{"qwen36"}) // evict qwen36
|
||||
step("sdxl", []string{"qwen35"}) // evict qwen35
|
||||
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
|
||||
|
||||
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
|
||||
if got := servedOrder(eff); !sameOrder(got, want) {
|
||||
t.Fatalf("serve order=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sameOrder(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
|
||||
// a's load fails: its caller is errored and b proceeds.
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
if eff.errored("a") != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
|
||||
// caller has disconnected by serve time, the queue advances to the next request.
|
||||
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's caller is gone by grant time
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
|
||||
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 (caller gone)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(reqCh("a")) // starts loading a
|
||||
cancelled := reqCh("b")
|
||||
s.OnRequest(cancelled) // queued behind a
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queued=%d want 1", len(s.queued))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelled)
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a completes; b is gone, so nothing starts for it.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading)
|
||||
s.OnRequest(req("b")) // queued
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 3 {
|
||||
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
|
||||
// actively serving does not strand the queue: OnUnload stops the process but
|
||||
// leaves the active request to end via OnServeDone, which then advances.
|
||||
func TestSerial_OnUnload_WhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately (a ready)
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Unload a while it is serving: the process is stopped, but the queue must
|
||||
// not advance yet — the active serve is still outstanding.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
|
||||
}
|
||||
|
||||
// The killed handler returns -> OnServeDone advances to b.
|
||||
eff.states["a"] = process.StateStopped
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
// Unload a: its active load is failed and a is stopped.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if eff.errored("a") != 1 {
|
||||
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
|
||||
}
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
// b was queued and not unloaded; with a's load cancelled it now starts.
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/event"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
@@ -340,6 +342,28 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
||||
// Pin the resolved model so the router skips body/query extraction.
|
||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||
|
||||
// If the path matches an upstream.ignorePaths entry and the model is
|
||||
// not already loaded, refuse the request without triggering a swap. The
|
||||
// server was not able to process the response because the model was not
|
||||
// already loaded.
|
||||
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||
if !re.MatchString(remainingPath) {
|
||||
continue
|
||||
}
|
||||
if s.local.Handles(modelID) {
|
||||
state, ok := s.local.RunningModels()[modelID]
|
||||
if !ok || state != process.StateReady {
|
||||
shared.SendResponse(w, r, http.StatusConflict,
|
||||
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Either the model is already loaded (no swap would be triggered)
|
||||
// or this is a peer model (peer proxying never swaps). Fall through
|
||||
// to normal dispatch.
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.local.Handles(modelID):
|
||||
s.local.ServeHTTP(w, r)
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/config"
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||
)
|
||||
|
||||
@@ -156,6 +158,91 @@ func upstreamMetricsServer(response string) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
|
||||
// Compile a pattern that matches static asset suffixes.
|
||||
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
|
||||
|
||||
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
// running is nil/empty: model is not in RunningModels() => not loaded.
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not loaded") {
|
||||
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
|
||||
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||
s := newTestServer(local, newStubRouter(nil, ""))
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
|
||||
// Peer routers do not appear via RunningModels on the local router;
|
||||
// they should fall through to normal dispatch without 409.
|
||||
local := newStubRouter(nil, "")
|
||||
peer := newStubRouter([]string{"m1"}, "peer-body")
|
||||
s := newTestServer(local, peer)
|
||||
s.cfg = config.Config{
|
||||
Models: map[string]config.ModelConfig{"m1": {}},
|
||||
Upstream: config.UpstreamConfig{
|
||||
IgnorePaths: []*regexp.Regexp{pattern},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||
|
||||
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
|
||||
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||
s := upstreamMetricsServer(resp)
|
||||
|
||||
@@ -105,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
if s.perf == nil {
|
||||
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"enabled": true,
|
||||
"sys_stats": sysStats,
|
||||
"gpu_stats": gpuStats,
|
||||
})
|
||||
|
||||
+115
-24
@@ -25,6 +25,8 @@ import (
|
||||
// TokenMetrics holds token usage and performance metrics.
|
||||
type TokenMetrics struct {
|
||||
CachedTokens int `json:"cache_tokens"`
|
||||
DraftTokens int `json:"draft_tokens"`
|
||||
DraftAccTokens int `json:"draft_acc_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||
@@ -42,6 +44,7 @@ type ActivityLogEntry struct {
|
||||
Tokens TokenMetrics `json:"tokens"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
HasCapture bool `json:"has_capture"`
|
||||
ErrorMsg string `json:"error_msg,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@@ -123,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
// record parses a completed response body and stores/emits an activity entry.
|
||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
||||
// requests, with cf controlling which request/response parts are retained.
|
||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
||||
// Successful requests store a zstd+CBOR capture (when enabled) with cf
|
||||
// controlling which parts are retained. Failed (non-200) requests capture the
|
||||
// request only and set ErrorMsg to a description of the failure, so the error
|
||||
// can be inspected without storing unreadable raw response bytes. reqBody and
|
||||
// reqHeaders are the request data buffered before dispatch.
|
||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||
tm := ActivityLogEntry{
|
||||
Timestamp: time.Now(),
|
||||
@@ -150,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
|
||||
if recorder.Status() != http.StatusOK {
|
||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||
queueAndEmit()
|
||||
decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
|
||||
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
// Capture the request only; the failure is surfaced via ErrorMsg
|
||||
// rather than storing the (possibly undisplayable) response body.
|
||||
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
|
||||
mp.emitMetric(tm)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -165,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
|
||||
queueAndEmit()
|
||||
return
|
||||
}
|
||||
@@ -203,28 +215,99 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
||||
}
|
||||
|
||||
tm.ID = mp.queueMetrics(tm)
|
||||
if mp.enableCaptures {
|
||||
capture := ReqRespCapture{
|
||||
ID: tm.ID,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
if mp.addCapture(capture) {
|
||||
tm.HasCapture = true
|
||||
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
|
||||
mp.emitMetric(tm)
|
||||
}
|
||||
|
||||
// storeCapture assembles a ReqRespCapture for id, honoring the captureFields
|
||||
// mask, and stores it when captures are enabled. body is the response body to
|
||||
// capture (already decompressed by the caller); pass nil to omit it. Returns
|
||||
// true if a capture was stored.
|
||||
func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
|
||||
if !mp.enableCaptures {
|
||||
return false
|
||||
}
|
||||
capture := ReqRespCapture{
|
||||
ID: id,
|
||||
ReqPath: r.URL.Path,
|
||||
ReqHeaders: reqHeaders,
|
||||
}
|
||||
if cf&captureReqBody != 0 {
|
||||
capture.ReqBody = reqBody
|
||||
}
|
||||
if cf&captureRespHeaders != 0 {
|
||||
capture.RespHeaders = headerMap(recorder.Header())
|
||||
redactHeaders(capture.RespHeaders)
|
||||
delete(capture.RespHeaders, "Content-Encoding")
|
||||
}
|
||||
if cf&captureRespBody != 0 {
|
||||
capture.RespBody = body
|
||||
}
|
||||
return mp.addCapture(capture)
|
||||
}
|
||||
|
||||
// decodeResponseBody returns the buffered response body, decompressing it when
|
||||
// the upstream set a Content-Encoding we recognize. On decompression failure it
|
||||
// logs a warning and returns an error so the caller can record a description
|
||||
// (via ErrorMsg) instead of storing unreadable raw bytes.
|
||||
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
|
||||
body := recorder.body.Bytes()
|
||||
if len(body) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
encoding := recorder.Header().Get("Content-Encoding")
|
||||
if encoding == "" {
|
||||
return body, nil
|
||||
}
|
||||
decoded, err := decompressBody(body, encoding)
|
||||
if err != nil {
|
||||
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
|
||||
return nil, err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// errorMessagePaths lists JSON paths where a human-readable error message can
|
||||
// live across OpenAI- and llama.cpp-style error responses.
|
||||
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
|
||||
|
||||
// extractErrorMessage pulls a human-readable error string from a JSON error
|
||||
// response. Returns "" if no message is found or the body is not valid JSON.
|
||||
func extractErrorMessage(body []byte) string {
|
||||
if !gjson.ValidBytes(body) {
|
||||
return ""
|
||||
}
|
||||
parsed := gjson.ParseBytes(body)
|
||||
for _, path := range errorMessagePaths {
|
||||
v := parsed.Get(path)
|
||||
if v.Exists() && v.Type == gjson.String {
|
||||
if s := strings.TrimSpace(v.String()); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
mp.emitMetric(tm)
|
||||
return ""
|
||||
}
|
||||
|
||||
// failedErrorMessage builds a human-readable description for a non-200 response.
|
||||
// It prefers an error message parsed from the (decompressed) body and falls back
|
||||
// to the HTTP status text. A non-nil decErr indicates the body could not be
|
||||
// decoded, in which case the decode error is described instead.
|
||||
func failedErrorMessage(status int, body []byte, decErr error) string {
|
||||
const maxLen = 500
|
||||
if decErr != nil {
|
||||
return fmt.Sprintf("response decode failed: %v", decErr)
|
||||
}
|
||||
if msg := extractErrorMessage(body); msg != "" {
|
||||
if len(msg) > maxLen {
|
||||
msg = msg[:maxLen] + "..."
|
||||
}
|
||||
return msg
|
||||
}
|
||||
if text := http.StatusText(status); text != "" {
|
||||
return fmt.Sprintf("%d %s", status, text)
|
||||
}
|
||||
return fmt.Sprintf("HTTP %d", status)
|
||||
}
|
||||
|
||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||
@@ -345,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
durationMs := wallDurationMs
|
||||
tokensPerSecond := -1.0
|
||||
promptPerSecond := -1.0
|
||||
draftTokens := -1
|
||||
draftAccTokens := -1
|
||||
|
||||
if timings.Exists() {
|
||||
inputTokens = timings.Get("prompt_n").Int()
|
||||
@@ -358,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||
cachedTokens = cachedValue.Int()
|
||||
}
|
||||
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
|
||||
draftTokens = int(timings.Get("draft_n").Int())
|
||||
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
|
||||
}
|
||||
}
|
||||
|
||||
return ActivityLogEntry{
|
||||
@@ -365,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
||||
Model: modelID,
|
||||
Tokens: TokenMetrics{
|
||||
CachedTokens: int(cachedTokens),
|
||||
DraftTokens: draftTokens,
|
||||
DraftAccTokens: draftAccTokens,
|
||||
InputTokens: int(inputTokens),
|
||||
OutputTokens: int(outputTokens),
|
||||
PromptPerSecond: promptPerSecond,
|
||||
|
||||
@@ -90,6 +90,172 @@ func TestMetricsMonitor_RecordMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
reqHeaders := map[string]string{"content-type": "application/json"}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Header().Set("Content-Type", "application/json")
|
||||
copier.WriteHeader(http.StatusBadGateway)
|
||||
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
|
||||
|
||||
reqBody := []byte(`{"model":"m","messages":[]}`)
|
||||
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
entry := entries[0]
|
||||
if entry.RespStatusCode != http.StatusBadGateway {
|
||||
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
|
||||
}
|
||||
if entry.ErrorMsg != "model unavailable" {
|
||||
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
|
||||
}
|
||||
if !entry.HasCapture {
|
||||
t.Fatal("failed request should capture the request so it can be inspected")
|
||||
}
|
||||
|
||||
got := mm.getCaptureByID(entry.ID)
|
||||
if got == nil {
|
||||
t.Fatal("capture not found")
|
||||
}
|
||||
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
|
||||
t.Errorf("req body = %q", got.ReqBody)
|
||||
}
|
||||
if len(got.RespBody) != 0 {
|
||||
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
|
||||
}
|
||||
if got.RespHeaders["Content-Type"] != "application/json" {
|
||||
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
|
||||
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.WriteHeader(http.StatusBadGateway)
|
||||
copier.Write([]byte("<html>upstream down</html>"))
|
||||
|
||||
mm.record("m", r, copier, captureAll, nil, nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].ErrorMsg != "502 Bad Gateway" {
|
||||
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.WriteHeader(http.StatusInternalServerError)
|
||||
copier.Write([]byte(`{"error":"boom"}`))
|
||||
|
||||
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].HasCapture {
|
||||
t.Fatal("captures disabled, HasCapture should be false")
|
||||
}
|
||||
// ErrorMsg is independent of whether captures are enabled.
|
||||
if entries[0].ErrorMsg != "boom" {
|
||||
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
|
||||
}
|
||||
if mm.getCaptureByID(entries[0].ID) != nil {
|
||||
t.Fatal("no capture should be stored when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Header().Set("Content-Encoding", "gzip")
|
||||
copier.WriteHeader(http.StatusOK)
|
||||
copier.Write([]byte("not-really-gzip"))
|
||||
|
||||
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||
|
||||
entries := mm.getMetrics()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||
}
|
||||
if entries[0].ErrorMsg == "" {
|
||||
t.Fatal("expected ErrorMsg for decompression failure")
|
||||
}
|
||||
// Raw bytes must not be stored when the body could not be decoded.
|
||||
if entries[0].HasCapture {
|
||||
t.Fatal("decompression failure should not store a capture")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
|
||||
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||
|
||||
// No Content-Encoding: body returned unchanged.
|
||||
w := httptest.NewRecorder()
|
||||
copier := newBodyCopier(w)
|
||||
copier.Write([]byte("plain"))
|
||||
got, err := mm.decodeResponseBody(copier, "/p")
|
||||
if err != nil || string(got) != "plain" {
|
||||
t.Fatalf("plain body = %q, err = %v", got, err)
|
||||
}
|
||||
|
||||
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
|
||||
w2 := httptest.NewRecorder()
|
||||
copier2 := newBodyCopier(w2)
|
||||
copier2.Header().Set("Content-Encoding", "gzip")
|
||||
copier2.Write([]byte("not-really-gzip"))
|
||||
got, err = mm.decodeResponseBody(copier2, "/p")
|
||||
if err == nil {
|
||||
t.Fatal("expected decompression error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("expected nil body on failure, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ExtractErrorMessage(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
|
||||
{"string error", `{"error":"bad request"}`, "bad request"},
|
||||
{"message field", `{"message":"nope"}`, "nope"},
|
||||
{"detail field", `{"detail":"oops"}`, "oops"},
|
||||
{"object error ignored", `{"error":{"code":42}}`, ""},
|
||||
{"no error", `{"usage":{}}`, ""},
|
||||
{"invalid json", `not-json`, ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||
// /infill responses are arrays; timings live in the last element.
|
||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
|
||||
// It fires OnChange when a file is added, removed, or has its mod time/size
|
||||
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
|
||||
// k8s ConfigMap projections where inotify is unreliable.
|
||||
//
|
||||
// The baseline poll establishes initial state and does not fire OnChange.
|
||||
type DirWatcher struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||
// derived from sorted filenames so two snapshots compare deterministically
|
||||
// regardless of readdir order. exists reflects whether the directory was
|
||||
// readable at scan time; a missing directory yields exists=false.
|
||||
type dirSnapshot struct {
|
||||
exists bool
|
||||
names []string
|
||||
states map[string]snapshot
|
||||
}
|
||||
|
||||
func newDirSnapshot() dirSnapshot {
|
||||
return dirSnapshot{states: make(map[string]snapshot)}
|
||||
}
|
||||
|
||||
// equal reports whether two snapshots describe the same file set and per-file
|
||||
// state. A missing directory (exists=false) is treated as equal to any other
|
||||
// missing directory regardless of cached names.
|
||||
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||
if !s.exists && !other.exists {
|
||||
return true
|
||||
}
|
||||
if s.exists != other.exists {
|
||||
return false
|
||||
}
|
||||
if len(s.names) != len(other.names) {
|
||||
return false
|
||||
}
|
||||
for i, n := range s.names {
|
||||
if other.names[i] != n {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, n := range s.names {
|
||||
a, b := s.states[n], other.states[n]
|
||||
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||
// OnChange whenever the directory's YAML file set changes.
|
||||
//
|
||||
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||
// transition back to present-with-content fires OnChange.
|
||||
func (w *DirWatcher) Run(ctx context.Context) {
|
||||
interval := w.Interval
|
||||
if interval <= 0 {
|
||||
interval = DefaultInterval
|
||||
}
|
||||
|
||||
prev := scanDir(w.Path)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cur := scanDir(w.Path)
|
||||
// Suppress transitions involving an empty or missing directory —
|
||||
// these are treated as transient rename-style writes, mirroring
|
||||
// the single-file Watcher. Only present-with-content →
|
||||
// present-with-content (changed) or no-content →
|
||||
// present-with-content fires OnChange.
|
||||
prevHasContent := prev.exists && len(prev.names) > 0
|
||||
curHasContent := cur.exists && len(cur.names) > 0
|
||||
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||
w.OnChange()
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||
// exists=false; the next successful scan will detect the recovery and fire
|
||||
// OnChange.
|
||||
func scanDir(dir string) dirSnapshot {
|
||||
snap := newDirSnapshot()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return snap // exists=false
|
||||
}
|
||||
snap.exists = true
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
fi, err := os.Stat(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
// File disappeared between ReadDir and Stat; skip it — the
|
||||
// next poll will observe the removal cleanly.
|
||||
continue
|
||||
}
|
||||
snap.names = append(snap.names, name)
|
||||
snap.states[name] = snapshot{
|
||||
exists: true,
|
||||
modTime: fi.ModTime(),
|
||||
size: fi.Size(),
|
||||
}
|
||||
}
|
||||
sort.Strings(snap.names)
|
||||
return snap
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||
// cancels the context and waits for Run to return.
|
||||
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||
}
|
||||
|
||||
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
time.Sleep(testInterval * 5)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||
}
|
||||
|
||||
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Adding a .txt file must not fire.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||
|
||||
// Adding a .yml file must fire.
|
||||
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||
}
|
||||
|
||||
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the directory. No fire expected on disappearance alone.
|
||||
require.NoError(t, os.RemoveAll(dir))
|
||||
time.Sleep(testInterval * 3)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||
|
||||
// Recreate the directory and a YAML file; the recovery should fire.
|
||||
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||
}
|
||||
|
||||
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||
// must stay quiet — treated as transient per the documented policy.
|
||||
// The transition back to content fires.
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||
|
||||
// Add a YAML file back; transition to present-with-content fires.
|
||||
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||
}
|
||||
|
||||
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() { w.Run(ctx); close(done) }()
|
||||
|
||||
time.Sleep(testInterval * 2)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return within 2s of cancel")
|
||||
}
|
||||
}
|
||||
+37
-19
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
|
||||
}
|
||||
|
||||
func main() {
|
||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
||||
flagConfig := flag.String("config", "", "path to config file")
|
||||
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
@@ -68,8 +69,8 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if *flagConfig == "" {
|
||||
slog.Error("-config is required")
|
||||
if *flagConfig == "" && *flagConfigDir == "" {
|
||||
slog.Error("at least one of -config or -config-dir must be provided")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -88,10 +89,9 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
configPath := *flagConfig
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
||||
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ func main() {
|
||||
|
||||
proxyLog.Info("reloading configuration")
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
proxyLog.Warnf("failed to reload config: %v", err)
|
||||
return
|
||||
@@ -230,19 +230,37 @@ func main() {
|
||||
defer watcherCancel()
|
||||
|
||||
if *flagWatchConfig {
|
||||
absConfigPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||
go func() {
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
|
||||
if *flagConfig != "" {
|
||||
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
go func() {
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
|
||||
if *flagConfigDir != "" {
|
||||
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
go func() {
|
||||
(&configwatcher.DirWatcher{
|
||||
Path: absConfigDir,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
import Performance from "./routes/Performance.svelte";
|
||||
import Playground from "./routes/Playground.svelte";
|
||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||
import { enableAPIEvents } from "./stores/api";
|
||||
import { enableAPIEvents, checkPerformanceEnabled } from "./stores/api";
|
||||
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||
import { currentRoute } from "./stores/route";
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
const cleanupScreenWidth = initScreenWidth();
|
||||
const cleanupSystemTheme = initSystemThemeListener();
|
||||
enableAPIEvents(true);
|
||||
checkPerformanceEnabled();
|
||||
|
||||
return () => {
|
||||
cleanupScreenWidth();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
||||
import { currentRoute } from "../stores/route";
|
||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||
import { performanceEnabled } from "../stores/api";
|
||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||
|
||||
function handleTitleChange(newTitle: string): void {
|
||||
@@ -84,16 +85,18 @@
|
||||
>
|
||||
Logs
|
||||
</a>
|
||||
<a
|
||||
href="/performance"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/performance", $currentRoute)}
|
||||
class:underline={isActive("/performance", $currentRoute)}
|
||||
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
||||
>
|
||||
Performance
|
||||
</a>
|
||||
{#if $performanceEnabled}
|
||||
<a
|
||||
href="/performance"
|
||||
use:link
|
||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||
class:font-semibold={isActive("/performance", $currentRoute)}
|
||||
class:underline={isActive("/performance", $currentRoute)}
|
||||
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
||||
>
|
||||
Performance
|
||||
</a>
|
||||
{/if}
|
||||
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
||||
{#if $themeMode === "system"}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||
|
||||
@@ -25,6 +25,8 @@ export interface Model {
|
||||
|
||||
export interface TokenMetrics {
|
||||
cache_tokens: number;
|
||||
draft_tokens: number;
|
||||
draft_acc_tokens: number;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
prompt_per_second: number;
|
||||
@@ -41,6 +43,7 @@ export interface ActivityLogEntry {
|
||||
tokens: TokenMetrics;
|
||||
duration_ms: number;
|
||||
has_capture: boolean;
|
||||
error_msg?: string;
|
||||
metadata?: Record<string, string>;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,11 +21,12 @@
|
||||
{ key: "time", label: "Time", defaultVisible: true },
|
||||
{ key: "model", label: "Model", defaultVisible: true },
|
||||
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
||||
{ key: "resp_status_code", label: "Status", defaultVisible: true },
|
||||
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||
{ key: "drafted", label: "Drafted", defaultVisible: false },
|
||||
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||
@@ -158,6 +159,10 @@
|
||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||
}
|
||||
|
||||
function formatDrafted(drafted: number, accepted: number): string {
|
||||
return drafted > 0 ? (accepted * 100 / drafted).toFixed(1) + "% (" + accepted + "/" + drafted + ")" : "-";
|
||||
}
|
||||
|
||||
function formatDuration(ms: number): string {
|
||||
return (ms / 1000).toFixed(2) + "s";
|
||||
}
|
||||
@@ -273,6 +278,8 @@
|
||||
Cached <Tooltip content="prompt tokens from cache" />
|
||||
{:else if key === "prompt"}
|
||||
Prompt <Tooltip content="new prompt tokens processed" />
|
||||
{:else if key === "drafted"}
|
||||
Drafted <Tooltip content="acceptance rate (accepted/drafted)" />
|
||||
{:else}
|
||||
{columnLabelMap[key] ?? key}
|
||||
{/if}
|
||||
@@ -301,7 +308,13 @@
|
||||
{:else if key === "req_path"}
|
||||
{metric.req_path || "-"}
|
||||
{:else if key === "resp_status_code"}
|
||||
{metric.resp_status_code || "-"}
|
||||
{#if metric.error_msg}
|
||||
<span class="text-red-500 dark:text-red-400 cursor-help" title={metric.error_msg}>
|
||||
{metric.resp_status_code || "-"}
|
||||
</span>
|
||||
{:else}
|
||||
{metric.resp_status_code || "-"}
|
||||
{/if}
|
||||
{:else if key === "resp_content_type"}
|
||||
{metric.resp_content_type || "-"}
|
||||
{:else if key === "cached"}
|
||||
@@ -310,6 +323,8 @@
|
||||
{metric.tokens.input_tokens.toLocaleString()}
|
||||
{:else if key === "generated"}
|
||||
{metric.tokens.output_tokens.toLocaleString()}
|
||||
{:else if key === "drafted"}
|
||||
{formatDrafted(metric.tokens.draft_tokens, metric.tokens.draft_acc_tokens)}
|
||||
{:else if key === "prompt_speed"}
|
||||
{formatSpeed(metric.tokens.prompt_per_second)}
|
||||
{:else if key === "gen_speed"}
|
||||
|
||||
@@ -19,6 +19,7 @@ export const proxyLogs = writable<string>("");
|
||||
export const upstreamLogs = writable<string>("");
|
||||
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||
export const inFlightRequests = writable<number>(0);
|
||||
export const performanceEnabled = writable<boolean>(false);
|
||||
export const versionInfo = writable<VersionInfo>({
|
||||
build_date: "unknown",
|
||||
commit: "unknown",
|
||||
@@ -210,6 +211,20 @@ export async function getCapture(id: number): Promise<ReqRespCapture | null> {
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkPerformanceEnabled(): Promise<void> {
|
||||
try {
|
||||
const response = await fetch("/api/performance");
|
||||
if (!response.ok) {
|
||||
performanceEnabled.set(false);
|
||||
return;
|
||||
}
|
||||
const data = await response.json();
|
||||
performanceEnabled.set(data.enabled);
|
||||
} catch {
|
||||
performanceEnabled.set(false);
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
||||
try {
|
||||
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
||||
|
||||
Reference in New Issue
Block a user